# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

-from typing import Callable, Optional, Tuple
+import warnings
+from typing import Callable, List, Optional, Tuple

import torch
from torch import Tensor
@@ -115,6 +116,54 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
            )
        return log_likelihood_batches + self.prior.log_prob(theta)  # type: ignore

+    def condition_on_theta(
+        self, local_theta: Tensor, dims_global_theta: List[int]
+    ) -> Callable:
+        r"""Returns a potential function conditioned on a subset of theta dimensions.
+
+        The goal of this function is to divide the original `theta` into a
+        `global_theta` we do inference over, and a `local_theta` we condition on (in
+        addition to conditioning on `x_o`). Thus, the returned potential function will
+        calculate $\prod_{i=1}^{N}p(x_i | local_theta_i, global_theta)$, where `x_i`
+        and `local_theta_i` are fixed and `global_theta` varies at inference time.
+
+        Args:
+            local_theta: The values of `theta` to condition on, one row per iid trial.
+            dims_global_theta: The indices of the columns in `theta` that will be
+                sampled, i.e., that are *not* conditioned. For example, if the
+                original theta has shape `(batch_dim, 3)` and
+                `dims_global_theta=[0, 1]`, then the potential will set
+                `theta[:, 2] = local_theta` at inference time.
+
+        Returns:
+            A potential function conditioned on the `local_theta`.
+        """
+
+        assert self.x_is_iid, "Conditioning is only supported for iid data."
+
+        def conditioned_potential(
+            theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
+        ) -> Tensor:
+            assert (
+                len(dims_global_theta) == theta.shape[1]
+            ), "dims_global_theta must match the number of parameters to sample."
+            global_theta = theta[:, dims_global_theta]
+            x_o = x_o if x_o is not None else self.x_o
+            # x needs shape (sample_dim (iid), batch_dim (xs), *event_shape)
+            if x_o.dim() < 3:
+                x_o = reshape_to_sample_batch_event(
+                    x_o, event_shape=x_o.shape[1:], leading_is_sample=self.x_is_iid
+                )
+
+            return _log_likelihood_over_iid_trials_and_local_theta(
+                x=x_o,
+                global_theta=global_theta,
+                local_theta=local_theta,
+                estimator=self.likelihood_estimator,
+                track_gradients=track_gradients,
+            )
+
+        return conditioned_potential
+

def _log_likelihoods_over_trials(
    x: Tensor,
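A minimal usage sketch of the new `condition_on_theta` API may help here (not part of this commit). It assumes a `likelihood_estimator` already trained on the full `theta = (global_0, global_1, local)` and a `prior`; those names, the dimensionalities, and the toy data are placeholders for illustration only.

```python
import torch

from sbi.inference.potentials.likelihood_based_potential import (
    likelihood_estimator_based_potential,
)

# Placeholders: `likelihood_estimator` (trained on 3-dim theta) and `prior` are
# assumed to exist; they are not defined in this sketch.
num_trials = 10
x_o = torch.randn(num_trials, 1)         # iid observations, one per trial
local_theta = torch.rand(num_trials, 1)  # per-trial values of the conditioned dim

potential_fn, _ = likelihood_estimator_based_potential(
    likelihood_estimator, prior, x_o=x_o
)

# Condition on the per-trial local theta; dims 0 and 1 remain free for inference.
conditioned_potential = potential_fn.condition_on_theta(
    local_theta, dims_global_theta=[0, 1]
)

# Evaluate the conditioned potential for a batch of candidate global parameters.
global_theta = torch.rand(5, 2)
log_potential = conditioned_potential(global_theta)  # summed over trials, shape (1, 5)
```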
@@ -172,6 +221,77 @@ def _log_likelihoods_over_trials(
    return log_likelihood_trial_sum


+def _log_likelihood_over_iid_trials_and_local_theta(
+    x: Tensor,
+    global_theta: Tensor,
+    local_theta: Tensor,
+    estimator: ConditionalDensityEstimator,
+    track_gradients: bool = False,
+) -> Tensor:
+    """Returns $\\sum_{i=1}^N \\log p(x_i | \\theta, local_theta_i)$.
+
+    `x` is a batch of iid data, and `local_theta` is a matching batch of condition
+    values that were part of `theta` but are treated as local iid variables at
+    inference time.
+
+    This function is different from `_log_likelihoods_over_trials` in that it moves
+    the iid batch dimension of `x` onto the batch dimension of `theta`. This is
+    needed when the likelihood estimator is conditioned on a batch of conditions
+    that are iid with the batch of `x`. It avoids evaluating the likelihood for
+    every combination of `x` and `local_theta`.
+
+    Args:
+        x: Data with shape `(sample_dim, x_batch_dim, *x_event_shape)`, where
+            sample_dim holds the iid trials and x_batch_dim holds a batch of xs,
+            e.g., non-iid observations.
+        global_theta: Batch of parameters with shape `(theta_batch_dim,
+            num_parameters)`.
+        local_theta: Batch of conditions with shape `(sample_dim, num_local_thetas)`;
+            must match x's `sample_dim`.
+        estimator: DensityEstimator.
+        track_gradients: Whether to track gradients.
+
+    Returns:
+        log_likelihood: Log likelihood for each x in x_batch_dim, for each theta in
+            theta_batch_dim, summed over all iid trials. Shape `(x_batch_dim,
+            theta_batch_dim)`.
+    """
+    assert x.dim() > 2, "x must have shape (sample_dim, batch_dim, *event_shape)."
+    assert (
+        local_theta.dim() == 2
+    ), "condition must have shape (sample_dim, num_conditions)."
+    assert global_theta.dim() == 2, "theta must have shape (batch_dim, num_parameters)."
+    num_trials, num_xs = x.shape[:2]
+    num_thetas = global_theta.shape[0]
+    assert (
+        local_theta.shape[0] == num_trials
+    ), "Condition batch size must match the number of iid trials in x."
+
+    # move the iid batch dimension onto the batch dimension of theta and repeat it there
+    x_repeated = torch.transpose(x, 0, 1).repeat_interleave(num_thetas, dim=1)
+
+    # construct theta and condition to cover all trial-theta combinations
+    theta_with_condition = torch.cat(
+        [
+            global_theta.repeat(num_trials, 1),  # repeat ABAB
+            local_theta.repeat_interleave(num_thetas, dim=0),  # repeat AABB
+        ],
+        dim=-1,
+    )
+
+    with torch.set_grad_enabled(track_gradients):
+        # Calculate likelihood in one batch. Returns (1, num_trials * num_thetas).
+        log_likelihood_trial_batch = estimator.log_prob(
+            x_repeated, condition=theta_with_condition
+        )
+        # Reshape to (x-trials x parameters), sum over trial-log likelihoods.
+        log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
+            num_xs, num_trials, num_thetas
+        ).sum(1)
+
+    return log_likelihood_trial_sum
+
+
def mixed_likelihood_estimator_based_potential(
    likelihood_estimator: MixedDensityEstimator,
    prior: Distribution,
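To make the ABAB/AABB pairing above concrete, here is a small standalone check (toy values, not part of the commit) showing that the concatenation covers every trial-theta combination in the order expected by the later `reshape(num_xs, num_trials, num_thetas)`.

```python
import torch

num_trials, num_thetas = 2, 3
global_theta = torch.tensor([[10.0], [20.0], [30.0]])  # (num_thetas, 1)
local_theta = torch.tensor([[0.1], [0.2]])             # (num_trials, 1)

# repeat ABAB: the theta index cycles fastest -> 10, 20, 30, 10, 20, 30
rep_global = global_theta.repeat(num_trials, 1)
# repeat AABB: each trial's condition is copied once per theta -> 0.1, 0.1, 0.1, 0.2, 0.2, 0.2
rep_local = local_theta.repeat_interleave(num_thetas, dim=0)

pairs = torch.cat([rep_global, rep_local], dim=-1)  # (num_trials * num_thetas, 2)
print(pairs)
# Row k pairs global theta (k % num_thetas) with trial (k // num_thetas), which is
# exactly the layout that reshape(num_xs, num_trials, num_thetas) followed by
# .sum(1) relies on.
```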
@@ -192,6 +312,13 @@ def mixed_likelihood_estimator_based_potential(
        to unconstrained space.
    """

+    warnings.warn(
+        "This function is deprecated and will be removed in a future release. Use "
+        "`likelihood_estimator_based_potential` instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+
    device = str(next(likelihood_estimator.discrete_net.parameters()).device)

    potential_fn = MixedLikelihoodBasedPotential(
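Given the deprecation, migration would presumably look like the following sketch. The `mixed_estimator`, `prior`, and `x_o` names are placeholders, and the assumption that the generic factory accepts a `MixedDensityEstimator` directly comes from the warning text, not from this diff.

```python
from sbi.inference.potentials.likelihood_based_potential import (
    likelihood_estimator_based_potential,
    mixed_likelihood_estimator_based_potential,  # deprecated
)

# Before (now emits a DeprecationWarning):
potential_fn, theta_transform = mixed_likelihood_estimator_based_potential(
    likelihood_estimator=mixed_estimator, prior=prior, x_o=x_o
)

# After: the generic factory is the suggested replacement.
potential_fn, theta_transform = likelihood_estimator_based_potential(
    likelihood_estimator=mixed_estimator, prior=prior, x_o=x_o
)
```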
@@ -212,6 +339,13 @@ def __init__(
    ):
        super().__init__(likelihood_estimator, prior, x_o, device)

+        warnings.warn(
+            "This class is deprecated and will be removed in a future release. Use "
+            "`LikelihoodBasedPotential` instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+

    def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
        prior_log_prob = self.prior.log_prob(theta)  # type: ignore
@@ -231,7 +365,6 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
        with torch.set_grad_enabled(track_gradients):
            # Call the specific log prob method of the mixed likelihood estimator as
            # this optimizes the evaluation of the discrete data part.
-            # TODO log_prob_iid
            log_likelihood_trial_batch = self.likelihood_estimator.log_prob(
                input=x,
                condition=theta.to(self.device),