Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Commit c62b624

Browse files
Modules reorganization (#140)
* Restructuration * Examples + tests * Comments
1 parent dfc7535 commit c62b624

40 files changed

+2204
-2181
lines changed

examples/gmrf.py

+10-12
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@
2222
from jax.example_libraries import optimizers
2323
from tqdm.notebook import tqdm
2424

25-
from pgmax.fg import graph
26-
from pgmax.groups import enumeration
27-
from pgmax.groups import variables as vgroup
25+
from pgmax import fgraph, fgroup, infer, vgroup
2826

2927
# %% [markdown]
3028
# # Visualize a trained GMRF
@@ -54,12 +52,12 @@
5452
# %%
5553
M, N = target_images.shape[-2:]
5654
num_states = np.sum(n_clones)
57-
variables = vgroup.NDVariableArray(num_states=num_states, shape=(M, N))
58-
fg = graph.FactorGraph(variables)
55+
variables = vgroup.NDVarArray(num_states=num_states, shape=(M, N))
56+
fg = fgraph.FactorGraph(variables)
5957

6058
# %%
6159
# Create top-down factors
62-
top_down = enumeration.PairwiseFactorGroup(
60+
top_down = fgroup.PairwiseFactorGroup(
6361
variables_for_factors=[
6462
[variables[ii, jj], variables[ii + 1, jj]]
6563
for ii in range(M - 1)
@@ -68,7 +66,7 @@
6866
)
6967

7068
# Create left-right factors
71-
left_right = enumeration.PairwiseFactorGroup(
69+
left_right = fgroup.PairwiseFactorGroup(
7270
variables_for_factors=[
7371
[variables[ii, jj], variables[ii, jj + 1]]
7472
for ii in range(M)
@@ -77,14 +75,14 @@
7775
)
7876

7977
# Create diagonal factors
80-
diagonal0 = enumeration.PairwiseFactorGroup(
78+
diagonal0 = fgroup.PairwiseFactorGroup(
8179
variables_for_factors=[
8280
[variables[ii, jj], variables[ii + 1, jj + 1]]
8381
for ii in range(M - 1)
8482
for jj in range(N - 1)
8583
],
8684
)
87-
diagonal1 = enumeration.PairwiseFactorGroup(
85+
diagonal1 = fgroup.PairwiseFactorGroup(
8886
variables_for_factors=[
8987
[variables[ii, jj], variables[ii - 1, jj + 1]]
9088
for ii in range(1, M)
@@ -96,7 +94,7 @@
9694
fg.add_factors([top_down, left_right, diagonal0, diagonal1])
9795

9896
# %%
99-
bp = graph.BP(fg.bp_state, temperature=1.0)
97+
bp = infer.BP(fg.bp_state, temperature=1.0)
10098

10199
# %%
102100
log_potentials = {
@@ -114,7 +112,7 @@
114112
target_image = target_images[idx]
115113
evidence = jnp.log(jnp.where(noisy_image[..., None] == 0, p_contour, 1 - p_contour))
116114
target = prototype_targets[target_image]
117-
marginals = graph.get_marginals(
115+
marginals = infer.get_marginals(
118116
bp.get_beliefs(
119117
bp.run_bp(
120118
bp.init(
@@ -162,7 +160,7 @@
162160
def loss(noisy_image, target_image, log_potentials):
163161
evidence = jnp.log(jnp.where(noisy_image[..., None] == 0, p_contour, 1 - p_contour))
164162
target = prototype_targets[target_image]
165-
marginals = graph.get_marginals(
163+
marginals = infer.get_marginals(
166164
bp.get_beliefs(
167165
bp.run_bp(
168166
bp.init(

examples/ising_model.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,14 @@
2020
import matplotlib.pyplot as plt
2121
import numpy as np
2222

23-
from pgmax.fg import graph
24-
from pgmax.groups import enumeration
25-
from pgmax.groups import variables as vgroup
23+
from pgmax import fgraph, fgroup, infer, vgroup
2624

2725
# %% [markdown]
2826
# ### Construct variable grid, initialize factor graph, and add factors
2927

3028
# %%
31-
variables = vgroup.NDVariableArray(num_states=2, shape=(50, 50))
32-
fg = graph.FactorGraph(variable_groups=variables)
29+
variables = vgroup.NDVarArray(num_states=2, shape=(50, 50))
30+
fg = fgraph.FactorGraph(variable_groups=variables)
3331

3432
variables_for_factors = []
3533
for ii in range(50):
@@ -39,7 +37,7 @@
3937
variables_for_factors.append([variables[ii, jj], variables[kk, jj]])
4038
variables_for_factors.append([variables[ii, jj], variables[ii, ll]])
4139

42-
factor_group = enumeration.PairwiseFactorGroup(
40+
factor_group = fgroup.PairwiseFactorGroup(
4341
variables_for_factors=variables_for_factors,
4442
log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]),
4543
)
@@ -49,7 +47,7 @@
4947
# ### Run inference and visualize results
5048

5149
# %%
52-
bp = graph.BP(fg.bp_state, temperature=0)
50+
bp = infer.BP(fg.bp_state, temperature=0)
5351

5452
# %%
5553
bp_arrays = bp.init(
@@ -59,7 +57,7 @@
5957
beliefs = bp.get_beliefs(bp_arrays)
6058

6159
# %%
62-
img = graph.decode_map_states(beliefs)[variables]
60+
img = infer.decode_map_states(beliefs)[variables]
6361
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
6462
ax.imshow(img)
6563

examples/pmp_binary_deconvolution.py

+10-14
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@
2828
from scipy.special import logit
2929
from tqdm.notebook import tqdm
3030

31-
from pgmax.fg import graph
32-
from pgmax.groups import logical
33-
from pgmax.groups import variables as vgroup
31+
from pgmax import fgraph, fgroup, infer, vgroup
3432

3533

3634
# %%
@@ -117,28 +115,26 @@ def plot_images(images, display=True, nr=None):
117115
s_width = im_width - feat_width + 1
118116

119117
# Binary features
120-
W = vgroup.NDVariableArray(
121-
num_states=2, shape=(n_chan, n_feat, feat_height, feat_width)
122-
)
118+
W = vgroup.NDVarArray(num_states=2, shape=(n_chan, n_feat, feat_height, feat_width))
123119

124120
# Binary indicators of features locations
125-
S = vgroup.NDVariableArray(num_states=2, shape=(n_images, n_feat, s_height, s_width))
121+
S = vgroup.NDVarArray(num_states=2, shape=(n_images, n_feat, s_height, s_width))
126122

127123
# Auxiliary binary variables combining W and S
128-
SW = vgroup.NDVariableArray(
124+
SW = vgroup.NDVarArray(
129125
num_states=2,
130126
shape=(n_images, n_chan, im_height, im_width, n_feat, feat_height, feat_width),
131127
)
132128

133129
# Binary images obtained by convolution
134-
X = vgroup.NDVariableArray(num_states=2, shape=X_gt.shape)
130+
X = vgroup.NDVarArray(num_states=2, shape=X_gt.shape)
135131

136132
# %% [markdown]
137133
# For computation efficiency, we construct large FactorGroups instead of individual factors
138134

139135
# %%
140136
# Factor graph
141-
fg = graph.FactorGraph(variable_groups=[S, W, SW, X])
137+
fg = fgraph.FactorGraph(variable_groups=[S, W, SW, X])
142138

143139
# Define the ANDFactors
144140
variables_for_ANDFactors = []
@@ -173,7 +169,7 @@ def plot_images(images, display=True, nr=None):
173169
variables_for_ORFactors_dict[X_var].append(SW_var)
174170

175171
# Add ANDFactorGroup, which is computationally efficient
176-
AND_factor_group = logical.ANDFactorGroup(variables_for_ANDFactors)
172+
AND_factor_group = fgroup.ANDFactorGroup(variables_for_ANDFactors)
177173
fg.add_factors(AND_factor_group)
178174

179175
# Define the ORFactors
@@ -183,7 +179,7 @@ def plot_images(images, display=True, nr=None):
183179
]
184180

185181
# Add ORFactorGroup, which is computationally efficient
186-
OR_factor_group = logical.ORFactorGroup(variables_for_ORFactors)
182+
OR_factor_group = fgroup.ORFactorGroup(variables_for_ORFactors)
187183
fg.add_factors(OR_factor_group)
188184

189185
for factor_type, factor_groups in fg.factor_groups.items():
@@ -202,7 +198,7 @@ def plot_images(images, display=True, nr=None):
202198
# in the same manner does not change X, so this naturally results in multiple equivalent modes.
203199

204200
# %%
205-
bp = graph.BP(fg.bp_state, temperature=0.0)
201+
bp = infer.BP(fg.bp_state, temperature=0.0)
206202

207203
# %% [markdown]
208204
# We first compute the evidence without perturbation, similar to the PMP paper.
@@ -246,7 +242,7 @@ def plot_images(images, display=True, nr=None):
246242
)(bp_arrays)
247243

248244
beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
249-
map_states = graph.decode_map_states(beliefs)
245+
map_states = infer.decode_map_states(beliefs)
250246

251247
# %% [markdown]
252248
# Visualizing the MAP decoding, we see that we have 4 good random samples (one per row) from the posterior!

examples/rbm.py

+27-29
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,10 @@
2626
import matplotlib.pyplot as plt
2727
import numpy as np
2828

29-
from pgmax.fg import graph
30-
from pgmax.groups import enumeration
31-
from pgmax.groups import variables as vgroup
29+
from pgmax import fgraph, fgroup, infer, vgroup
3230

3331
# %% [markdown]
34-
# The [`pgmax.fg.graph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.html#module-pgmax.fg.graph) module contains core classes for specifying factor graphs and implementing LBP, while the [`pgmax.fg.groups`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.html#module-pgmax.fg.graph) module contains classes for specifying groups of variables/factors.
32+
# The [`pgmax.fgraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgraph.html#module-pgmax.fgraph) module contains classes for specifying factor graphs, the [`pgmax.fgroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.html#module-pgmax.vgroup) module contains classes for specifying groups of variables, the [`pgmax.vgroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.html#module-pgmax.fgroup) module contains classes for specifying groups of factors and the [`pgmax.infer`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.infer.html#module-pgmax.infer) module containing core functions to perform LBP.
3533
#
3634
# We next load the RBM trained in Sec. 5.5 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) on MNIST digits.
3735

@@ -47,23 +45,23 @@
4745

4846
# %%
4947
# Initialize factor graph
50-
hidden_variables = vgroup.NDVariableArray(num_states=2, shape=bh.shape)
51-
visible_variables = vgroup.NDVariableArray(num_states=2, shape=bv.shape)
52-
fg = graph.FactorGraph(variable_groups=[hidden_variables, visible_variables])
48+
hidden_variables = vgroup.NDVarArray(num_states=2, shape=bh.shape)
49+
visible_variables = vgroup.NDVarArray(num_states=2, shape=bv.shape)
50+
fg = fgraph.FactorGraph(variable_groups=[hidden_variables, visible_variables])
5351

5452
# %% [markdown]
55-
# [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup) (e.g. an [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)), or a list/dictionary of [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
53+
# [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVarArray.html#pgmax.fg.groups.NDVarArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VarGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VarGroup.html#pgmax.fg.groups.VarGroup) (e.g. an [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVarArray.html#pgmax.fg.groups.NDVarArray)), or a list of [`VarGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VarGroup.html#pgmax.fg.groups.VarGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
5654
#
5755
# After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph). We can add the unary and pairwise factors by grouping them using [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)
5856

5957
# %%
6058
# Create unary factors
61-
hidden_unaries = enumeration.EnumerationFactorGroup(
59+
hidden_unaries = fgroup.EnumFactorGroup(
6260
variables_for_factors=[[hidden_variables[ii]] for ii in range(bh.shape[0])],
6361
factor_configs=np.arange(2)[:, None],
6462
log_potentials=np.stack([np.zeros_like(bh), bh], axis=1),
6563
)
66-
visible_unaries = enumeration.EnumerationFactorGroup(
64+
visible_unaries = fgroup.EnumFactorGroup(
6765
variables_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])],
6866
factor_configs=np.arange(2)[:, None],
6967
log_potentials=np.stack([np.zeros_like(bv), bv], axis=1),
@@ -78,7 +76,7 @@
7876
for ii in range(bh.shape[0])
7977
for jj in range(bv.shape[0])
8078
]
81-
pairwise_factors = enumeration.PairwiseFactorGroup(
79+
pairwise_factors = fgroup.PairwiseFactorGroup(
8280
variables_for_factors=variables_for_factors,
8381
log_potential_matrix=log_potential_matrix,
8482
)
@@ -88,67 +86,67 @@
8886

8987

9088
# %% [markdown]
91-
# PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) for representing groups of similar factors. The code above makes use of [`EnumerationFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.EnumerationFactorGroup.html#pgmax.fg.groups.EnumerationFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.PairwiseFactorGroup.html#pgmax.fg.groups.PairwiseFactorGroup).
89+
# PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) for representing groups of similar factors. The code above makes use of [`EnumFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.EnumFactorGroup.html#pgmax.fg.groups.EnumFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.PairwiseFactorGroup.html#pgmax.fg.groups.PairwiseFactorGroup).
9290
#
9391
# A [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) takes as argument `variables_for_factors` which is a list of lists of the variables involved in the different factors, and additional arguments specific to each [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) (e.g. `factor_configs` or `log_potential_matrix` here).
9492
#
95-
# In this example, since we construct `fg` with variables `hidden_variables` and `visible_variables`, which are both [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)s, we can refer to the `ii`th hidden variable as `hidden_variables[ii]` and the `jj`th visible variable as `visible_variables[jj]`.
93+
# In this example, since we construct `fg` with variables `hidden_variables` and `visible_variables`, which are both [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVarArray.html#pgmax.fg.groups.NDVarArray)s, we can refer to the `ii`th hidden variable as `hidden_variables[ii]` and the `jj`th visible variable as `visible_variables[jj]`.
9694
#
9795
# An alternative way of creating the above factors is to add them iteratively without building the [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)s as below. This approach is not recommended as it is not computationally efficient.
9896
# ~~~python
99-
# from pgmax.factors import enumeration as enumeration_factor
97+
# from pgmax import factor
10098
# import itertools
10199
# from tqdm import tqdm
102100
#
103101
# # Add unary factors
104102
# for ii in range(bh.shape[0]):
105-
# factor = enumeration_factor.EnumerationFactor(
103+
# unary_factor = factor.EnumFactor(
106104
# variables=[hidden_variables[ii]],
107105
# factor_configs=np.arange(2)[:, None],
108106
# log_potentials=np.array([0, bh[ii]]),
109107
# )
110-
# fg.add_factors(factor)
108+
# fg.add_factors(unary_factor)
111109
#
112110
# for jj in range(bv.shape[0]):
113-
# factor = enumeration_factor.EnumerationFactor(
111+
# unary_factor = factor.EnumFactor(
114112
# variables=[visible_variables[jj]],
115113
# factor_configs=np.arange(2)[:, None],
116114
# log_potentials=np.array([0, bv[jj]]),
117115
# )
118-
# fg.add_factors(factor)
116+
# fg.add_factors(unary_factor)
119117
#
120118
# # Add pairwise factors
121119
# factor_configs = np.array(list(itertools.product(np.arange(2), repeat=2)))
122120
# for ii in tqdm(range(bh.shape[0])):
123121
# for jj in range(bv.shape[0]):
124-
# factor = enumeration_factor.EnumerationFactor(
122+
# pairwise_factor = factor.EnumFactor(
125123
# variables=[hidden_variables[ii], visible_variables[jj]],
126124
# factor_configs=factor_configs,
127125
# log_potentials=np.array([0, 0, 0, W[ii, jj]]),
128126
# )
129-
# fg.add_factors(factor)
127+
# fg.add_factors(pairwise_factor)
130128
# ~~~
131129
#
132130
# Once we have added the factors, we can run max-product LBP and get MAP decoding by
133131
# ~~~python
134-
# bp = graph.BP(fg.bp_state, temperature=0.0)
132+
# bp = infer.BP(fg.bp_state, temperature=0.0)
135133
# bp_arrays = bp.run_bp(bp.init(), num_iters=100, damping=0.5)
136134
# beliefs = bp.get_beliefs(bp_arrays)
137-
# map_states = graph.decode_map_states(beliefs)
135+
# map_states = infer.decode_map_states(beliefs)
138136
# ~~~
139137
# and run sum-product LBP and get estimated marginals by
140138
# ~~~python
141-
# bp = graph.BP(fg.bp_state, temperature=1.0)
139+
# bp = infer.BP(fg.bp_state, temperature=1.0)
142140
# bp_arrays = bp.run_bp(bp.init(), num_iters=100, damping=0.5)
143141
# beliefs = bp.get_beliefs(bp_arrays)
144-
# marginals = graph.get_marginals(beliefs)
142+
# marginals = infer.get_marginals(beliefs)
145143
# ~~~
146144
# More generally, PGMax implements LBP with temperature, with `temperature=0.0` and `temperature=1.0` corresponding to the commonly used max/sum-product LBP respectively.
147145
#
148146
# Now we are ready to demonstrate PMP sampling from RBM. PMP perturbs the model with [Gumbel](https://numpy.org/doc/stable/reference/random/generated/numpy.random.gumbel.html) unary potentials, and draws a sample from the RBM as the MAP decoding from running max-product LBP on the perturbed model
149147

150148
# %%
151-
bp = graph.BP(fg.bp_state, temperature=0.0)
149+
bp = infer.BP(fg.bp_state, temperature=0.0)
152150

153151
# %%
154152
bp_arrays = bp.init(
@@ -168,17 +166,17 @@
168166
# %%
169167
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
170168
ax.imshow(
171-
graph.decode_map_states(beliefs)[visible_variables].copy().reshape((28, 28)),
169+
infer.decode_map_states(beliefs)[visible_variables].copy().reshape((28, 28)),
172170
cmap="gray",
173171
)
174172
ax.axis("off")
175173

176174
# %% [markdown]
177175
# PGMax adopts a functional interface for implementing LBP: running LBP in PGMax starts with
178176
# ~~~python
179-
# bp = graph.BP(fg.bp_state, temperature=T)
177+
# bp = infer.BP(fg.bp_state, temperature=T)
180178
# ~~~
181-
# where the arguments of the `bp` are several useful functions to run LBP. In particular, `bp.init`, `bp.run_bp`, `bp.get_beliefs` are pure functions with no side-effects. This design choice means that we can easily apply JAX transformations like `jit`/`vmap`/`grad`, etc., to these functions, and additionally allows PGMax to seamlessly interact with other packages in the rapidly growing JAX ecosystem (see [here](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) and [here](https://github.com/n2cholas/awesome-jax)).
179+
# where the arguments of the `this_bp` are several useful functions to run LBP. In particular, `bp.init`, `bp.run_bp`, `bp.get_beliefs` are pure functions with no side-effects. This design choice means that we can easily apply JAX transformations like `jit`/`vmap`/`grad`, etc., to these functions, and additionally allows PGMax to seamlessly interact with other packages in the rapidly growing JAX ecosystem (see [here](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) and [here](https://github.com/n2cholas/awesome-jax)).
182180
#
183181
# As an example of applying `jax.vmap` to `bp.init`/`bp.run_bp`/`bp.get_beliefs` to process a batch of samples/models in parallel, instead of drawing one sample at a time as above, we can draw a batch of samples in parallel as follows:
184182

@@ -197,7 +195,7 @@
197195
)(bp_arrays)
198196

199197
beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
200-
map_states = graph.decode_map_states(beliefs)
198+
map_states = infer.decode_map_states(beliefs)
201199

202200
# %% [markdown]
203201
# Visualizing the MAP decodings, we see that we have sampled 10 MNIST digits in parallel!

0 commit comments

Comments
 (0)