@@ -257,29 +257,36 @@ def sample(
257257 num_chains : Optional [int ] = None ,
258258 init_strategy : Optional [str ] = None ,
259259 init_strategy_parameters : Optional [Dict [str , Any ]] = None ,
260- init_strategy_num_candidates : Optional [int ] = None ,
261- mcmc_parameters : Optional [Dict ] = None ,
262- mcmc_method : Optional [str ] = None ,
263- sample_with : Optional [str ] = None ,
264260 num_workers : Optional [int ] = None ,
265261 mp_context : Optional [str ] = None ,
266262 show_progress_bars : bool = True ,
267263 ) -> Tensor :
268- r"""Return samples from posterior distribution $p(\theta|x)$ with MCMC.
269-
270- Check the `__init__()` method for a description of all arguments as well as
271- their default values.
264+ r"""Draw samples from the approximate posterior distribution $p(\theta|x)$.
272265
273266 Args:
274267 sample_shape: Desired shape of samples that are drawn from posterior. If
275268 sample_shape is multidimensional we simply draw `sample_shape.numel()`
276269 samples and then reshape into the desired shape.
277- mcmc_parameters: Dictionary that is passed only to support the API of
278- `sbi` v0.17.2 or older.
279- mcmc_method: This argument only exists to keep backward-compatibility with
280- `sbi` v0.17.2 or older. Please use `method` instead.
281- sample_with: This argument only exists to keep backward-compatibility with
282- `sbi` v0.17.2 or older. If it is set, we instantly raise an error.
270+ x: Conditioning observation $x_o$. If not provided, uses the default `x`
271+ set via `.set_default_x()`.
272+ method: MCMC method to use. One of `slice_np`, `slice_np_vectorized`,
273+ `hmc_pyro`, `nuts_pyro`, `slice_pymc`, `hmc_pymc`, `nuts_pymc`.
274+ If not provided, uses the method specified at initialization.
275+ thin: Thinning factor for the chain. If not provided, uses the value
276+ specified at initialization.
277+ warmup_steps: Number of warmup steps to discard. If not provided, uses
278+ the value specified at initialization.
279+ num_chains: Number of MCMC chains to run. If not provided, uses the
280+ value specified at initialization.
281+ init_strategy: Initialization strategy for chains (`proposal`, `sir`,
282+ or `resample`). If not provided, uses the value specified at
283+ initialization.
284+ init_strategy_parameters: Parameters for the initialization strategy.
285+ If not provided, uses the value specified at initialization.
286+ num_workers: Number of CPU cores for parallelization. If not provided,
287+ uses the value specified at initialization.
288+ mp_context: Multiprocessing context (`fork` or `spawn`). If not provided,
289+ uses the value specified at initialization.
283290 show_progress_bars: Whether to show sampling progress monitor.
284291
285292 Returns:
@@ -301,46 +308,6 @@ def sample(
301308 if init_strategy_parameters is None
302309 else init_strategy_parameters
303310 )
304- if init_strategy_num_candidates is not None :
305- warn (
306- f"Passing `init_strategy_num_candidates` is deprecated as of sbi \
307- v0.19.0. Instead, use e.g., \
308- `init_strategy_parameters={ 'num_candidate_samples' : 1000} `" ,
309- stacklevel = 2 ,
310- )
311- self .init_strategy_parameters ["num_candidate_samples" ] = (
312- init_strategy_num_candidates
313- )
314- if sample_with is not None :
315- raise ValueError (
316- f"You set `sample_with={ sample_with } `. As of sbi v0.18.0, setting "
317- "`sample_with` is no longer supported. You have to rerun "
318- f"`.build_posterior(sample_with={ sample_with } ).`"
319- )
320- if mcmc_method is not None :
321- warn (
322- "You passed `mcmc_method` to `.sample()`. As of sbi v0.18.0, this "
323- "is deprecated and will be removed in a future release. Use `method` "
324- "instead of `mcmc_method`." ,
325- stacklevel = 2 ,
326- )
327- method = mcmc_method
328- if mcmc_parameters :
329- warn (
330- "You passed `mcmc_parameters` to `.sample()`. As of sbi v0.18.0, this "
331- "is deprecated and will be removed in a future release. Instead, pass "
332- "the variable to `.sample()` directly, e.g. "
333- "`posterior.sample((1,), num_chains=5)`." ,
334- stacklevel = 2 ,
335- )
336- # The following lines are only for backwards compatibility with sbi v0.17.2 or
337- # older.
338- m_p = mcmc_parameters or {} # define to shorten the variable name
339- method = _maybe_use_dict_entry (method , "mcmc_method" , m_p )
340- thin = _maybe_use_dict_entry (thin , "thin" , m_p )
341- warmup_steps = _maybe_use_dict_entry (warmup_steps , "warmup_steps" , m_p )
342- num_chains = _maybe_use_dict_entry (num_chains , "num_chains" , m_p )
343- init_strategy = _maybe_use_dict_entry (init_strategy , "init_strategy" , m_p )
344311 self .potential_ = self ._prepare_potential (method ) # type: ignore
345312
346313 initial_params = self ._get_initial_params (
@@ -415,9 +382,10 @@ def sample_batched(
415382 mp_context : Optional [str ] = None ,
416383 show_progress_bars : bool = True ,
417384 ) -> Tensor :
418- r"""Given a batch of observations [x_1, ..., x_B] this function samples from
419- posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
420- manner.
385+ r"""Draw samples from the posteriors for a batch of different xs.
386+
387+ Given a batch of observations `[x_1, ..., x_B]`, this method samples from
388+ posteriors $p(\theta|x_1), \ldots, p(\theta|x_B)$ in a vectorized manner.
421389
422390 Check the `__init__()` method for a description of all arguments as well as
423391 their default values.
@@ -1134,24 +1102,6 @@ def _process_thin_default(thin: int) -> int:
11341102 return thin
11351103
11361104
1137- def _maybe_use_dict_entry (default : Any , key : str , dict_to_check : Dict ) -> Any :
1138- """Returns `default` if `key` is not in the dict and otherwise the dict entry.
1139-
1140- This method exists only to keep backwards compatibility with `sbi` v0.17.2 or
1141- older. It allows passing `mcmc_parameters` to `.sample()`.
1142-
1143- Args:
1144- default: The default value if `key` is not in `dict_to_check`.
1145- key: The key for which to check in `dict_to_check`.
1146- dict_to_check: The dictionary to be checked.
1147-
1148- Returns:
1149- The potentially replaced value.
1150- """
1151- attribute = dict_to_check .get (key , default )
1152- return attribute
1153-
1154-
11551105def _num_required_args (func ):
11561106 """
11571107 Utility for counting the number of positional args in a function.
0 commit comments