@@ -84,15 +84,26 @@ class AutoGuide(ABC):
8484 ``*args,**kwargs`` as ``model()`` and returning a :class:`numpyro.plate`
8585 or iterable of plates. Plates not returned will be created
8686 automatically as usual. This is useful for data subsampling.
87+ :param bool forward_mode_differentiation: Whether to use forward-mode differentiation
88+ during model initialization. Defaults to False. This is useful for models that
89+ contain JAX primitives which are not supported by reverse-mode differentiation
90+ (e.g. :func:`jax.lax.while_loop`).
8791 """
8892
8993 def __init__ (
90- self , model , * , prefix = "auto" , init_loc_fn = init_to_uniform , create_plates = None
94+ self ,
95+ model ,
96+ * ,
97+ prefix = "auto" ,
98+ init_loc_fn = init_to_uniform ,
99+ create_plates = None ,
100+ forward_mode_differentiation = False ,
91101 ):
92102 self .model = model
93103 self .prefix = prefix
94104 self .init_loc_fn = init_loc_fn
95105 self .create_plates = create_plates
106+ self ._forward_mode_differentiation = forward_mode_differentiation
96107 self .prototype_trace = None
97108 self ._prototype_frames = {}
98109 self ._prototype_frame_full_sizes = {}
@@ -164,6 +175,7 @@ def _setup_prototype(self, *args, **kwargs):
164175 dynamic_args = True ,
165176 model_args = args ,
166177 model_kwargs = kwargs ,
178+ forward_mode_differentiation = self ._forward_mode_differentiation ,
167179 )
168180 self ._potential_fn = self ._potential_fn_gen (* args , ** kwargs )
169181 postprocess_fn = postprocess_fn_gen (* args , ** kwargs )
@@ -246,14 +258,26 @@ class AutoGuideList(AutoGuide):
246258 params = svi.get_params(svi_state)
247259
248260 :param callable model: a NumPyro model
261+ :param bool forward_mode_differentiation: Whether to use forward-mode differentiation
262+ during model initialization. Defaults to False.
249263 """
250264
251265 def __init__ (
252- self , model , * , prefix = "auto" , init_loc_fn = init_to_uniform , create_plates = None
266+ self ,
267+ model ,
268+ * ,
269+ prefix = "auto" ,
270+ init_loc_fn = init_to_uniform ,
271+ create_plates = None ,
272+ forward_mode_differentiation = False ,
253273 ):
254274 self ._guides = []
255275 super ().__init__ (
256- model , prefix = prefix , init_loc_fn = init_loc_fn , create_plates = create_plates
276+ model ,
277+ prefix = prefix ,
278+ init_loc_fn = init_loc_fn ,
279+ create_plates = create_plates ,
280+ forward_mode_differentiation = forward_mode_differentiation ,
257281 )
258282
259283 def append (self , part ):
@@ -363,6 +387,8 @@ class AutoNormal(AutoGuide):
363387 ``*args,**kwargs`` as ``model()`` and returning a :class:`numpyro.plate`
364388 or iterable of plates. Plates not returned will be created
365389 automatically as usual. This is useful for data subsampling.
390+ :param bool forward_mode_differentiation: Whether to use forward-mode differentiation
391+ during model initialization. Defaults to False.
366392 """
367393
368394 scale_constraint = constraints .softplus_positive
@@ -375,11 +401,16 @@ def __init__(
375401 init_loc_fn = init_to_uniform ,
376402 init_scale = 0.1 ,
377403 create_plates = None ,
404+ forward_mode_differentiation = False ,
378405 ):
379406 self ._init_scale = init_scale
380407 self ._event_dims = {}
381408 super ().__init__ (
382- model , prefix = prefix , init_loc_fn = init_loc_fn , create_plates = create_plates
409+ model ,
410+ prefix = prefix ,
411+ init_loc_fn = init_loc_fn ,
412+ create_plates = create_plates ,
413+ forward_mode_differentiation = forward_mode_differentiation ,
383414 )
384415
385416 def _setup_prototype (self , * args , ** kwargs ):
@@ -516,14 +547,26 @@ class AutoDelta(AutoGuide):
516547 ``*args,**kwargs`` as ``model()`` and returning a :class:`numpyro.plate`
517548 or iterable of plates. Plates not returned will be created
518549 automatically as usual. This is useful for data subsampling.
550+ :param bool forward_mode_differentiation: Whether to use forward-mode differentiation
551+ during model initialization. Defaults to False.
519552 """
520553
521554 def __init__ (
522- self , model , * , prefix = "auto" , init_loc_fn = init_to_median , create_plates = None
555+ self ,
556+ model ,
557+ * ,
558+ prefix = "auto" ,
559+ init_loc_fn = init_to_median ,
560+ create_plates = None ,
561+ forward_mode_differentiation = False ,
523562 ):
524563 self ._event_dims = {}
525564 super ().__init__ (
526- model , prefix = prefix , init_loc_fn = init_loc_fn , create_plates = create_plates
565+ model ,
566+ prefix = prefix ,
567+ init_loc_fn = init_loc_fn ,
568+ create_plates = create_plates ,
569+ forward_mode_differentiation = forward_mode_differentiation ,
527570 )
528571
529572 def _setup_prototype (self , * args , ** kwargs ):
@@ -853,6 +896,8 @@ class AutoDAIS(AutoContinuous):
853896 :param float init_scale: Initial scale for the standard deviation of
854897 the base variational distribution for each (unconstrained transformed)
855898 latent variable. Defaults to 0.1.
899+ :param bool forward_mode_differentiation: Whether to use forward-mode differentiation
900+ during model initialization. Defaults to False.
856901 """
857902
858903 def __init__ (
@@ -867,6 +912,7 @@ def __init__(
867912 prefix = "auto" ,
868913 init_loc_fn = init_to_uniform ,
869914 init_scale = 0.1 ,
915+ forward_mode_differentiation = False ,
870916 ):
871917 if K < 1 :
872918 raise ValueError ("K must satisfy K >= 1 (got K = {})" .format (K ))
@@ -889,7 +935,12 @@ def __init__(
889935 self .K = K
890936 self .base_dist = base_dist
891937 self ._init_scale = init_scale
892- super ().__init__ (model , prefix = prefix , init_loc_fn = init_loc_fn )
938+ super ().__init__ (
939+ model ,
940+ prefix = prefix ,
941+ init_loc_fn = init_loc_fn ,
942+ forward_mode_differentiation = forward_mode_differentiation ,
943+ )
893944
894945 def _setup_prototype (self , * args , ** kwargs ):
895946 super ()._setup_prototype (* args , ** kwargs )
@@ -1083,6 +1134,8 @@ def surrogate_model(X_surr, Y_surr):
10831134 :param float init_scale: Initial scale for the standard deviation of
10841135 the base variational distribution for each (unconstrained transformed)
10851136 latent variable. Defaults to 0.1.
1137+ :param bool forward_mode_differentiation: Whether to use forward-mode differentiation
1138+ during model initialization. Defaults to False.
10861139 """
10871140
10881141 def __init__ (
@@ -1098,6 +1151,7 @@ def __init__(
10981151 base_dist = "diagonal" ,
10991152 init_loc_fn = init_to_uniform ,
11001153 init_scale = 0.1 ,
1154+ forward_mode_differentiation = False ,
11011155 ):
11021156 super ().__init__ (
11031157 model ,
@@ -1109,6 +1163,7 @@ def __init__(
11091163 init_loc_fn = init_loc_fn ,
11101164 init_scale = init_scale ,
11111165 base_dist = base_dist ,
1166+ forward_mode_differentiation = forward_mode_differentiation ,
11121167 )
11131168
11141169 self .surrogate_model = surrogate_model
@@ -1127,6 +1182,7 @@ def _setup_prototype(self, *args, **kwargs):
11271182 dynamic_args = False ,
11281183 model_args = (),
11291184 model_kwargs = {},
1185+ forward_mode_differentiation = self ._forward_mode_differentiation ,
11301186 )
11311187 )
11321188
@@ -1299,6 +1355,8 @@ def local_model(theta):
12991355 data points in the subsample plate) or local (i.e. each data point in the
13001356 subsample plate has individual parameters). Note that we do not use global
13011357 parameters for the base distribution.
1358+ :param bool forward_mode_differentiation: Whether to use forward-mode differentiation
1359+ during model initialization. Defaults to False.
13021360 """
13031361
13041362 def __init__ (
@@ -1316,9 +1374,14 @@ def __init__(
13161374 init_scale = 0.1 ,
13171375 subsample_plate = None ,
13181376 use_global_dais_params = False ,
1377+ forward_mode_differentiation = False ,
13191378 ):
1320- # init_loc_fn is only used to inspect the model.
1321- super ().__init__ (model , prefix = prefix , init_loc_fn = init_to_uniform )
1379+ super ().__init__ (
1380+ model ,
1381+ prefix = prefix ,
1382+ init_loc_fn = init_to_uniform ,
1383+ forward_mode_differentiation = forward_mode_differentiation ,
1384+ )
13221385 if K < 1 :
13231386 raise ValueError ("K must satisfy K >= 1 (got K = {})" .format (K ))
13241387 if eta_init <= 0.0 or eta_init >= eta_max :
0 commit comments