Skip to content

Commit 781b0e7

Browse files
authored
enhancement (#2163)
1 parent 0fcf121 commit 781b0e7

2 files changed

Lines changed: 84 additions & 9 deletions

File tree

numpyro/infer/autoguide.py

Lines changed: 72 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,26 @@ class AutoGuide(ABC):
8484
``*args,**kwargs`` as ``model()`` and returning a :class:`numpyro.plate`
8585
or iterable of plates. Plates not returned will be created
8686
automatically as usual. This is useful for data subsampling.
87+
:param bool forward_mode_differentiation: Whether to use forward-mode differentiation
88+
during model initialization. Defaults to False. This is useful for models that
89+
contain JAX primitives which are not supported by reverse-mode differentiation
90+
(e.g. :func:`jax.lax.while_loop`).
8791
"""
8892

8993
def __init__(
90-
self, model, *, prefix="auto", init_loc_fn=init_to_uniform, create_plates=None
94+
self,
95+
model,
96+
*,
97+
prefix="auto",
98+
init_loc_fn=init_to_uniform,
99+
create_plates=None,
100+
forward_mode_differentiation=False,
91101
):
92102
self.model = model
93103
self.prefix = prefix
94104
self.init_loc_fn = init_loc_fn
95105
self.create_plates = create_plates
106+
self._forward_mode_differentiation = forward_mode_differentiation
96107
self.prototype_trace = None
97108
self._prototype_frames = {}
98109
self._prototype_frame_full_sizes = {}
@@ -164,6 +175,7 @@ def _setup_prototype(self, *args, **kwargs):
164175
dynamic_args=True,
165176
model_args=args,
166177
model_kwargs=kwargs,
178+
forward_mode_differentiation=self._forward_mode_differentiation,
167179
)
168180
self._potential_fn = self._potential_fn_gen(*args, **kwargs)
169181
postprocess_fn = postprocess_fn_gen(*args, **kwargs)
@@ -246,14 +258,26 @@ class AutoGuideList(AutoGuide):
246258
params = svi.get_params(svi_state)
247259
248260
:param callable model: a NumPyro model
261+
:param bool forward_mode_differentiation: Whether to use forward-mode differentiation
262+
during model initialization. Defaults to False.
249263
"""
250264

251265
def __init__(
252-
self, model, *, prefix="auto", init_loc_fn=init_to_uniform, create_plates=None
266+
self,
267+
model,
268+
*,
269+
prefix="auto",
270+
init_loc_fn=init_to_uniform,
271+
create_plates=None,
272+
forward_mode_differentiation=False,
253273
):
254274
self._guides = []
255275
super().__init__(
256-
model, prefix=prefix, init_loc_fn=init_loc_fn, create_plates=create_plates
276+
model,
277+
prefix=prefix,
278+
init_loc_fn=init_loc_fn,
279+
create_plates=create_plates,
280+
forward_mode_differentiation=forward_mode_differentiation,
257281
)
258282

259283
def append(self, part):
@@ -363,6 +387,8 @@ class AutoNormal(AutoGuide):
363387
``*args,**kwargs`` as ``model()`` and returning a :class:`numpyro.plate`
364388
or iterable of plates. Plates not returned will be created
365389
automatically as usual. This is useful for data subsampling.
390+
:param bool forward_mode_differentiation: Whether to use forward-mode differentiation
391+
during model initialization. Defaults to False.
366392
"""
367393

368394
scale_constraint = constraints.softplus_positive
@@ -375,11 +401,16 @@ def __init__(
375401
init_loc_fn=init_to_uniform,
376402
init_scale=0.1,
377403
create_plates=None,
404+
forward_mode_differentiation=False,
378405
):
379406
self._init_scale = init_scale
380407
self._event_dims = {}
381408
super().__init__(
382-
model, prefix=prefix, init_loc_fn=init_loc_fn, create_plates=create_plates
409+
model,
410+
prefix=prefix,
411+
init_loc_fn=init_loc_fn,
412+
create_plates=create_plates,
413+
forward_mode_differentiation=forward_mode_differentiation,
383414
)
384415

385416
def _setup_prototype(self, *args, **kwargs):
@@ -516,14 +547,26 @@ class AutoDelta(AutoGuide):
516547
``*args,**kwargs`` as ``model()`` and returning a :class:`numpyro.plate`
517548
or iterable of plates. Plates not returned will be created
518549
automatically as usual. This is useful for data subsampling.
550+
:param bool forward_mode_differentiation: Whether to use forward-mode differentiation
551+
during model initialization. Defaults to False.
519552
"""
520553

521554
def __init__(
522-
self, model, *, prefix="auto", init_loc_fn=init_to_median, create_plates=None
555+
self,
556+
model,
557+
*,
558+
prefix="auto",
559+
init_loc_fn=init_to_median,
560+
create_plates=None,
561+
forward_mode_differentiation=False,
523562
):
524563
self._event_dims = {}
525564
super().__init__(
526-
model, prefix=prefix, init_loc_fn=init_loc_fn, create_plates=create_plates
565+
model,
566+
prefix=prefix,
567+
init_loc_fn=init_loc_fn,
568+
create_plates=create_plates,
569+
forward_mode_differentiation=forward_mode_differentiation,
527570
)
528571

529572
def _setup_prototype(self, *args, **kwargs):
@@ -853,6 +896,8 @@ class AutoDAIS(AutoContinuous):
853896
:param float init_scale: Initial scale for the standard deviation of
854897
the base variational distribution for each (unconstrained transformed)
855898
latent variable. Defaults to 0.1.
899+
:param bool forward_mode_differentiation: Whether to use forward-mode differentiation
900+
during model initialization. Defaults to False.
856901
"""
857902

858903
def __init__(
@@ -867,6 +912,7 @@ def __init__(
867912
prefix="auto",
868913
init_loc_fn=init_to_uniform,
869914
init_scale=0.1,
915+
forward_mode_differentiation=False,
870916
):
871917
if K < 1:
872918
raise ValueError("K must satisfy K >= 1 (got K = {})".format(K))
@@ -889,7 +935,12 @@ def __init__(
889935
self.K = K
890936
self.base_dist = base_dist
891937
self._init_scale = init_scale
892-
super().__init__(model, prefix=prefix, init_loc_fn=init_loc_fn)
938+
super().__init__(
939+
model,
940+
prefix=prefix,
941+
init_loc_fn=init_loc_fn,
942+
forward_mode_differentiation=forward_mode_differentiation,
943+
)
893944

894945
def _setup_prototype(self, *args, **kwargs):
895946
super()._setup_prototype(*args, **kwargs)
@@ -1083,6 +1134,8 @@ def surrogate_model(X_surr, Y_surr):
10831134
:param float init_scale: Initial scale for the standard deviation of
10841135
the base variational distribution for each (unconstrained transformed)
10851136
latent variable. Defaults to 0.1.
1137+
:param bool forward_mode_differentiation: Whether to use forward-mode differentiation
1138+
during model initialization. Defaults to False.
10861139
"""
10871140

10881141
def __init__(
@@ -1098,6 +1151,7 @@ def __init__(
10981151
base_dist="diagonal",
10991152
init_loc_fn=init_to_uniform,
11001153
init_scale=0.1,
1154+
forward_mode_differentiation=False,
11011155
):
11021156
super().__init__(
11031157
model,
@@ -1109,6 +1163,7 @@ def __init__(
11091163
init_loc_fn=init_loc_fn,
11101164
init_scale=init_scale,
11111165
base_dist=base_dist,
1166+
forward_mode_differentiation=forward_mode_differentiation,
11121167
)
11131168

11141169
self.surrogate_model = surrogate_model
@@ -1127,6 +1182,7 @@ def _setup_prototype(self, *args, **kwargs):
11271182
dynamic_args=False,
11281183
model_args=(),
11291184
model_kwargs={},
1185+
forward_mode_differentiation=self._forward_mode_differentiation,
11301186
)
11311187
)
11321188

@@ -1299,6 +1355,8 @@ def local_model(theta):
12991355
data points in the subsample plate) or local (i.e. each data point in the
13001356
subsample plate has individual parameters). Note that we do not use global
13011357
parameters for the base distribution.
1358+
:param bool forward_mode_differentiation: Whether to use forward-mode differentiation
1359+
during model initialization. Defaults to False.
13021360
"""
13031361

13041362
def __init__(
@@ -1316,9 +1374,14 @@ def __init__(
13161374
init_scale=0.1,
13171375
subsample_plate=None,
13181376
use_global_dais_params=False,
1377+
forward_mode_differentiation=False,
13191378
):
1320-
# init_loc_fn is only used to inspect the model.
1321-
super().__init__(model, prefix=prefix, init_loc_fn=init_to_uniform)
1379+
super().__init__(
1380+
model,
1381+
prefix=prefix,
1382+
init_loc_fn=init_to_uniform,
1383+
forward_mode_differentiation=forward_mode_differentiation,
1384+
)
13221385
if K < 1:
13231386
raise ValueError("K must satisfy K >= 1 (got K = {})".format(K))
13241387
if eta_init <= 0.0 or eta_init >= eta_max:

test/infer/test_autoguide.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,3 +1378,15 @@ def model(n: int, x: jnp.ndarray):
13781378
)
13791379
state = svi.init(jax.random.key(2), x=x)
13801380
svi.update(state, x=subset)
1381+
1382+
1383+
@pytest.mark.parametrize("auto_class", [AutoNormal, AutoDelta])
1384+
def test_autoguide_forward_mode_differentiation(auto_class):
1385+
def model():
1386+
x = numpyro.sample("x", dist.Normal(0, 1))
1387+
y = lax.while_loop(lambda x: x < 10, lambda x: x + 1, x)
1388+
numpyro.sample("obs", dist.Normal(y, 1), obs=1.0)
1389+
1390+
guide = auto_class(model, forward_mode_differentiation=True)
1391+
svi = SVI(model, guide, optim.Adam(0.01), loss=Trace_ELBO())
1392+
svi.run(random.key(0), 10, forward_mode_differentiation=True)

0 commit comments

Comments
 (0)