99
1010from sbi .inference .posteriors .base_posterior import NeuralPosterior
1111from sbi .inference .potentials .score_based_potential import (
12+ CallableDifferentiablePotentialFunction ,
1213 PosteriorScoreBasedPotential ,
1314 score_estimator_based_potential ,
1415)
1516from sbi .neural_nets .estimators .score_estimator import ConditionalScoreEstimator
1617from sbi .neural_nets .estimators .shape_handling import (
1718 reshape_to_batch_event ,
1819)
20+ from sbi .samplers .rejection import rejection
1921from sbi .samplers .score .correctors import Corrector
2022from sbi .samplers .score .diffuser import Diffuser
2123from sbi .samplers .score .predictors import Predictor
2224from sbi .sbi_types import Shape
2325from sbi .utils import check_prior
26+ from sbi .utils .sbiutils import gradient_ascent , within_support
2427from sbi .utils .torchutils import ensure_theta_batched
2528
2629
@@ -46,7 +49,7 @@ def __init__(
4649 prior : Distribution ,
4750 max_sampling_batch_size : int = 10_000 ,
4851 device : Optional [str ] = None ,
49- enable_transform : bool = False ,
52+ enable_transform : bool = True ,
5053 sample_with : str = "sde" ,
5154 ):
5255 """
@@ -110,7 +113,6 @@ def sample(
110113
111114 Args:
112115 sample_shape: Shape of the samples to be drawn.
113- x: Deprecated - use `.set_default_x()` prior to `.sample()`.
114116 predictor: The predictor for the diffusion-based sampler. Can be a string or
115117 a custom predictor following the API in `sbi.samplers.score.predictors`.
116118 Currently, only `euler_maruyama` is implemented.
@@ -136,23 +138,39 @@ def sample(
136138
137139 x = self ._x_else_default_x (x )
138140 x = reshape_to_batch_event (x , self .score_estimator .condition_shape )
139- self .potential_fn .set_x (x )
141+ self .potential_fn .set_x (x , x_is_iid = True )
142+
143+ num_samples = torch .Size (sample_shape ).numel ()
140144
141145 if self .sample_with == "ode" :
142- samples = self .sample_via_zuko (sample_shape = sample_shape , x = x )
143- elif self .sample_with == "sde" :
144- samples = self ._sample_via_diffusion (
145- sample_shape = sample_shape ,
146- predictor = predictor ,
147- corrector = corrector ,
148- predictor_params = predictor_params ,
149- corrector_params = corrector_params ,
150- steps = steps ,
151- ts = ts ,
146+ samples = rejection .accept_reject_sample (
147+ proposal = self .sample_via_ode ,
148+ accept_reject_fn = lambda theta : within_support (self .prior , theta ),
149+ num_samples = num_samples ,
150+ show_progress_bars = show_progress_bars ,
152151 max_sampling_batch_size = max_sampling_batch_size ,
152+ )[0 ]
153+ elif self .sample_with == "sde" :
154+ proposal_sampling_kwargs = {
155+ "predictor" : predictor ,
156+ "corrector" : corrector ,
157+ "predictor_params" : predictor_params ,
158+ "corrector_params" : corrector_params ,
159+ "steps" : steps ,
160+ "ts" : ts ,
161+ "max_sampling_batch_size" : max_sampling_batch_size ,
162+ "show_progress_bars" : show_progress_bars ,
163+ }
164+ samples = rejection .accept_reject_sample (
165+ proposal = self ._sample_via_diffusion ,
166+ accept_reject_fn = lambda theta : within_support (self .prior , theta ),
167+ num_samples = num_samples ,
153168 show_progress_bars = show_progress_bars ,
154- )
169+ max_sampling_batch_size = max_sampling_batch_size ,
170+ proposal_sampling_kwargs = proposal_sampling_kwargs ,
171+ )[0 ]
155172
173+ samples = samples .reshape (sample_shape + self .score_estimator .input_shape )
156174 return samples
157175
158176 def _sample_via_diffusion (
@@ -171,7 +189,6 @@ def _sample_via_diffusion(
171189
172190 Args:
173191 sample_shape: Shape of the samples to be drawn.
174- x: Deprecated - use `.set_default_x()` prior to `.sample()`.
175192 predictor: The predictor for the diffusion-based sampler. Can be a string or
176193 a custom predictor following the API in `sbi.samplers.score.predictors`.
177194 Currently, only `euler_maruyama` is implemented.
@@ -222,11 +239,10 @@ def _sample_via_diffusion(
222239 )
223240 samples = torch .cat (samples , dim = 0 )[:num_samples ]
224241
225- return samples . reshape ( sample_shape + self . score_estimator . input_shape )
242+ return samples
226243
227- def sample_via_zuko (
244+ def sample_via_ode (
228245 self ,
229- x : Tensor ,
230246 sample_shape : Shape = torch .Size (),
231247 ) -> Tensor :
232248 r"""Return samples from posterior distribution with probability flow ODE.
@@ -243,10 +259,12 @@ def sample_via_zuko(
243259 """
244260 num_samples = torch .Size (sample_shape ).numel ()
245261
246- flow = self .potential_fn .get_continuous_normalizing_flow (condition = x )
262+ flow = self .potential_fn .get_continuous_normalizing_flow (
263+ condition = self .potential_fn .x_o
264+ )
247265 samples = flow .sample (torch .Size ((num_samples ,)))
248266
249- return samples . reshape ( sample_shape + self . score_estimator . input_shape )
267+ return samples
250268
251269 def log_prob (
252270 self ,
@@ -291,19 +309,73 @@ def sample_batched(
291309 self ,
292310 sample_shape : torch .Size ,
293311 x : Tensor ,
312+ predictor : Union [str , Predictor ] = "euler_maruyama" ,
313+ corrector : Optional [Union [str , Corrector ]] = None ,
314+ predictor_params : Optional [Dict ] = None ,
315+ corrector_params : Optional [Dict ] = None ,
316+ steps : int = 500 ,
317+ ts : Optional [Tensor ] = None ,
294318 max_sampling_batch_size : int = 10000 ,
295319 show_progress_bars : bool = True ,
296320 ) -> Tensor :
297- raise NotImplementedError (
298- "Batched sampling is not implemented for ScorePosterior."
321+ num_samples = torch .Size (sample_shape ).numel ()
322+ x = reshape_to_batch_event (x , self .score_estimator .condition_shape )
323+ condition_dim = len (self .score_estimator .condition_shape )
324+ batch_shape = x .shape [:- condition_dim ]
325+ batch_size = batch_shape .numel ()
326+ self .potential_fn .set_x (x )
327+
328+ max_sampling_batch_size = (
329+ self .max_sampling_batch_size
330+ if max_sampling_batch_size is None
331+ else max_sampling_batch_size
299332 )
300333
334+ if self .sample_with == "ode" :
335+ samples = rejection .accept_reject_sample (
336+ proposal = self .sample_via_ode ,
337+ accept_reject_fn = lambda theta : within_support (self .prior , theta ),
338+ num_samples = num_samples ,
339+ num_xos = batch_size ,
340+ show_progress_bars = show_progress_bars ,
341+ max_sampling_batch_size = max_sampling_batch_size ,
342+ proposal_sampling_kwargs = {"x" : x },
343+ )[0 ]
344+ samples = samples .reshape (
345+ sample_shape + batch_shape + self .score_estimator .input_shape
346+ )
347+ elif self .sample_with == "sde" :
348+ proposal_sampling_kwargs = {
349+ "predictor" : predictor ,
350+ "corrector" : corrector ,
351+ "predictor_params" : predictor_params ,
352+ "corrector_params" : corrector_params ,
353+ "steps" : steps ,
354+ "ts" : ts ,
355+ "max_sampling_batch_size" : max_sampling_batch_size ,
356+ "show_progress_bars" : show_progress_bars ,
357+ }
358+ samples = rejection .accept_reject_sample (
359+ proposal = self ._sample_via_diffusion ,
360+ accept_reject_fn = lambda theta : within_support (self .prior , theta ),
361+ num_samples = num_samples ,
362+ num_xos = batch_size ,
363+ show_progress_bars = show_progress_bars ,
364+ max_sampling_batch_size = max_sampling_batch_size ,
365+ proposal_sampling_kwargs = proposal_sampling_kwargs ,
366+ )[0 ]
367+ samples = samples .reshape (
368+ sample_shape + batch_shape + self .score_estimator .input_shape
369+ )
370+
371+ return samples
372+
301373 def map (
302374 self ,
303375 x : Optional [Tensor ] = None ,
304376 num_iter : int = 1000 ,
305377 num_to_optimize : int = 1000 ,
306- learning_rate : float = 1e-5 ,
378+ learning_rate : float = 0.01 ,
307379 init_method : Union [str , Tensor ] = "posterior" ,
308380 num_init_samples : int = 1000 ,
309381 save_best_every : int = 1000 ,
@@ -351,17 +423,41 @@ def map(
351423 Returns:
352424 The MAP estimate.
353425 """
354- raise NotImplementedError (
355- "MAP estimation is currently not working accurately for ScorePosterior."
356- )
357- return super ().map (
358- x = x ,
359- num_iter = num_iter ,
360- num_to_optimize = num_to_optimize ,
361- learning_rate = learning_rate ,
362- init_method = init_method ,
363- num_init_samples = num_init_samples ,
364- save_best_every = save_best_every ,
365- show_progress_bars = show_progress_bars ,
366- force_update = force_update ,
367- )
426+ if x is not None :
427+ raise ValueError (
428+ "Passing `x` directly to `.map()` has been deprecated."
429+ "Use `.self_default_x()` to set `x`, and then run `.map()` "
430+ )
431+
432+ if self .default_x is None :
433+ raise ValueError (
434+ "Default `x` has not been set."
435+ "To set the default, use the `.set_default_x()` method."
436+ )
437+
438+ if self ._map is None or force_update :
439+ self .potential_fn .set_x (self .default_x )
440+ callable_potential_fn = CallableDifferentiablePotentialFunction (
441+ self .potential_fn
442+ )
443+ if init_method == "posterior" :
444+ inits = self .sample ((num_init_samples ,))
445+ elif init_method == "proposal" :
446+ inits = self .proposal .sample ((num_init_samples ,)) # type: ignore
447+ elif isinstance (init_method , Tensor ):
448+ inits = init_method
449+ else :
450+ raise ValueError
451+
452+ self ._map = gradient_ascent (
453+ potential_fn = callable_potential_fn ,
454+ inits = inits ,
455+ theta_transform = self .theta_transform ,
456+ num_iter = num_iter ,
457+ num_to_optimize = num_to_optimize ,
458+ learning_rate = learning_rate ,
459+ save_best_every = save_best_every ,
460+ show_progress_bars = show_progress_bars ,
461+ )[0 ]
462+
463+ return self ._map
0 commit comments