@@ -1,7 +1,6 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
 
-from functools import partial
 from typing import Optional, Tuple
 
 import torch
@@ -79,41 +78,36 @@ def set_x(
         self,
         x_o: Optional[Tensor],
         x_is_iid: Optional[bool] = False,
-        rebuild_flow: Optional[bool] = True,
+        atol: float = 1e-5,
+        rtol: float = 1e-6,
+        exact: bool = True,
     ):
         """
         Set the observed data and whether it is IID.
+
+        Rebuilds the continuous normalizing flow if the observed data is set.
+
         Args:
-            x_o: The observed data.
-            x_is_iid: Whether the observed data is IID (if batch_dim>1).
-            rebuild_flow: Whether to save (overwrrite) a low-tolerance flow model, useful if
-                the flow needs to be evaluated many times (e.g. for MAP calculation).
+            x_o: The observed data.
+            x_is_iid: Whether the observed data is IID (if batch_dim>1).
+            atol: Absolute tolerance for the ODE solver.
+            rtol: Relative tolerance for the ODE solver.
+            exact: Whether to use the exact ODE solver.
         """
         super().set_x(x_o, x_is_iid)
-        if rebuild_flow and self._x_o is not None:
-            # By default, we want a high-tolerance flow.
-            # This flow will be used mainly for MAP calculations, hence we want to save
-            # it instead of rebuilding it every time.
-            self.flow = self.rebuild_flow(atol=1e-2, rtol=1e-3, exact=True)
+        if self._x_o is not None:
+            self.flow = self.rebuild_flow(atol=atol, rtol=rtol, exact=exact)
 
     def __call__(
         self,
         theta: Tensor,
         track_gradients: bool = True,
-        rebuild_flow: bool = True,
-        atol: float = 1e-5,
-        rtol: float = 1e-6,
-        exact: bool = True,
     ) -> Tensor:
         """Return the potential (posterior log prob) via probability flow ODE.
 
         Args:
             theta: The parameters at which to evaluate the potential.
             track_gradients: Whether to track gradients.
-            rebuild_flow: Whether to rebuild the CNF for accurate log_prob evaluation.
-            atol: Absolute tolerance for the ODE solver.
-            rtol: Relative tolerance for the ODE solver.
-            exact: Whether to use the exact ODE solver.
 
         Returns:
             The potential function, i.e., the log probability of the posterior.
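
Note on the API change above (commentary, not part of the patch): the ODE-solver
settings now bind once in `set_x`, and `__call__` always evaluates the flow stored
on the object. A minimal usage sketch, assuming an already constructed
`PosteriorScoreBasedPotential` named `potential_fn` and tensors `x_o`, `theta`
(both hypothetical):

    # Build the CNF once for this observation, with explicit solver tolerances.
    potential_fn.set_x(x_o, x_is_iid=False, atol=1e-5, rtol=1e-6, exact=True)
    # Subsequent evaluations reuse the flow cached in `potential_fn.flow`.
    log_prob = potential_fn(theta, track_gradients=False)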
@@ -123,15 +117,9 @@ def __call__(
             theta, theta.shape[1:], leading_is_sample=True
         )
         self.score_estimator.eval()
-        # use rebuild_flow to evaluate log_prob with better precision, without
-        # overwriting self.flow
-        if rebuild_flow or self.flow is None:
-            flow = self.rebuild_flow(atol=atol, rtol=rtol, exact=exact)
-        else:
-            flow = self.flow
 
         with torch.set_grad_enabled(track_gradients):
-            log_probs = flow.log_prob(theta_density_estimator).squeeze(-1)
+            log_probs = self.flow.log_prob(theta_density_estimator).squeeze(-1)
             # Force probability to be zero outside prior support.
             in_prior_support = within_support(self.prior, theta)
 
@@ -217,7 +205,7 @@ def rebuild_flow(
         x_density_estimator = reshape_to_batch_event(
             self.x_o, event_shape=self.score_estimator.condition_shape
         )
-        assert x_density_estimator.shape[0] == 1, (
+        assert x_density_estimator.shape[0] == 1 or not self.x_is_iid, (
             "PosteriorScoreBasedPotential supports only x batchsize of 1`."
         )
 
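
Note on the relaxed assertion (commentary, not part of the patch): a batch
dimension larger than 1 is now accepted whenever the observations are not flagged
as IID. A sketch of the admissible cases, with hypothetical tensors:

    potential_fn.set_x(x_single)                 # shape (1, *event_shape): allowed
    potential_fn.set_x(x_batch, x_is_iid=False)  # batch > 1, non-IID: now allowed
    # batch > 1 with x_is_iid=True still trips the assert in rebuild_flow.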
@@ -312,9 +300,8 @@ def __init__(self, posterior_score_based_potential):
         self.posterior_score_based_potential = posterior_score_based_potential
 
     def __call__(self, input):
-        prepared_potential = partial(
-            self.posterior_score_based_potential.__call__, rebuild_flow=False
-        )
         return DifferentiablePotentialFunction.apply(
-            input, prepared_potential, self.posterior_score_based_potential.gradient
+            input,
+            self.posterior_score_based_potential.__call__,
+            self.posterior_score_based_potential.gradient,
         )
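
Note on the wrapper change (commentary, not part of the patch): since `__call__` no
longer takes `rebuild_flow`, the `functools.partial` shim is unnecessary and the
potential's `__call__` is handed to autograd directly. For orientation, a minimal
sketch of the general `torch.autograd.Function` pattern such a wrapper follows (an
illustration under assumed shapes, not sbi's actual implementation):

    import torch

    class DifferentiablePotential(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input, call_function, gradient_function):
            ctx.save_for_backward(input)
            ctx.gradient_function = gradient_function
            # Evaluate the potential (posterior log prob) without building a graph.
            return call_function(input)

        @staticmethod
        def backward(ctx, grad_output):
            (input,) = ctx.saved_tensors
            # Chain the upstream gradient with the potential's analytic gradient.
            grad = grad_output.unsqueeze(-1) * ctx.gradient_function(input)
            # Only the tensor argument receives a gradient.
            return grad, None, None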