import matplotlib.pyplot as plt
import numpy as np

-from pgmax.fg import graph
-from pgmax.groups import enumeration
-from pgmax.groups import variables as vgroup
+from pgmax import fgraph, fgroup, infer, vgroup

# %% [markdown]
-# The [`pgmax.fg.graph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.html#module-pgmax.fg.graph) module contains core classes for specifying factor graphs and implementing LBP, while the [`pgmax.fg.groups`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.html#module-pgmax.fg.graph) module contains classes for specifying groups of variables/factors.
+# The [`pgmax.fgraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgraph.html#module-pgmax.fgraph) module contains classes for specifying factor graphs, the [`pgmax.vgroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.vgroup.html#module-pgmax.vgroup) module contains classes for specifying groups of variables, the [`pgmax.fgroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.html#module-pgmax.fgroup) module contains classes for specifying groups of factors, and the [`pgmax.infer`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.infer.html#module-pgmax.infer) module contains core functions to perform LBP.
#
# We next load the RBM trained in Sec. 5.5 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) on MNIST digits.

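Taken together, the renamed modules cover the whole workflow: `vgroup` declares the variables, `fgraph` holds the factor graph, `fgroup` attaches factors, and `infer` runs loopy belief propagation. The following is a minimal sketch of that flow with the new flat imports; the two-variable model, shapes and potentials are illustrative placeholders, not the RBM built below.

~~~python
import numpy as np
from pgmax import fgraph, fgroup, infer, vgroup

# Two binary variables on a tiny 1D grid (placeholder size)
variables = vgroup.NDVarArray(num_states=2, shape=(2,))
fg = fgraph.FactorGraph(variable_groups=[variables])

# One pairwise factor that favours the two variables agreeing
# (a single 2 x 2 log-potential matrix shared by the factors in the group)
pairwise = fgroup.PairwiseFactorGroup(
    variables_for_factors=[[variables[0], variables[1]]],
    log_potential_matrix=np.array([[1.0, 0.0], [0.0, 1.0]]),
)
fg.add_factors(pairwise)

# Max-product LBP (temperature=0.0) followed by MAP decoding
bp = infer.BP(fg.bp_state, temperature=0.0)
bp_arrays = bp.run_bp(bp.init(), num_iters=10, damping=0.5)
map_states = infer.decode_map_states(bp.get_beliefs(bp_arrays))
~~~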
...

# %%
# Initialize factor graph
-hidden_variables = vgroup.NDVariableArray(num_states=2, shape=bh.shape)
-visible_variables = vgroup.NDVariableArray(num_states=2, shape=bv.shape)
-fg = graph.FactorGraph(variable_groups=[hidden_variables, visible_variables])
+hidden_variables = vgroup.NDVarArray(num_states=2, shape=bh.shape)
+visible_variables = vgroup.NDVarArray(num_states=2, shape=bv.shape)
+fg = fgraph.FactorGraph(variable_groups=[hidden_variables, visible_variables])

# %% [markdown]
-# [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup) (e.g. an [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)), or a list/dictionary of [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
+# [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVarArray.html#pgmax.fg.groups.NDVarArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VarGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VarGroup.html#pgmax.fg.groups.VarGroup) (e.g. an [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVarArray.html#pgmax.fg.groups.NDVarArray)), or a list of [`VarGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VarGroup.html#pgmax.fg.groups.VarGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
#
# After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph). We can add the unary and pairwise factors by grouping them using [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)s.

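As a small illustration of the initialization options just described (the extra variable group here is a placeholder, not part of the RBM), a `FactorGraph` can be built from a single `VarGroup` or from a list of them, and individual variables are retrieved by indexing:

~~~python
# A single VarGroup passed directly ...
pixels = vgroup.NDVarArray(num_states=2, shape=(28 * 28,))
fg_single = fgraph.FactorGraph(variable_groups=pixels)

# ... or a list of VarGroups, as done for the RBM above
fg_pair = fgraph.FactorGraph(variable_groups=[hidden_variables, visible_variables])

# Individual variables are addressed by index, e.g. the first hidden unit
first_hidden = hidden_variables[0]
~~~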
# %%
# Create unary factors
-hidden_unaries = enumeration.EnumerationFactorGroup(
+hidden_unaries = fgroup.EnumFactorGroup(
    variables_for_factors=[[hidden_variables[ii]] for ii in range(bh.shape[0])],
    factor_configs=np.arange(2)[:, None],
    log_potentials=np.stack([np.zeros_like(bh), bh], axis=1),
)
-visible_unaries = enumeration.EnumerationFactorGroup(
+visible_unaries = fgroup.EnumFactorGroup(
    variables_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])],
    factor_configs=np.arange(2)[:, None],
    log_potentials=np.stack([np.zeros_like(bv), bv], axis=1),
...

    for ii in range(bh.shape[0])
    for jj in range(bv.shape[0])
]
-pairwise_factors = enumeration.PairwiseFactorGroup(
+pairwise_factors = fgroup.PairwiseFactorGroup(
    variables_for_factors=variables_for_factors,
    log_potential_matrix=log_potential_matrix,
)
...


# %% [markdown]
-# PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) for representing groups of similar factors. The code above makes use of [`EnumerationFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.EnumerationFactorGroup.html#pgmax.fg.groups.EnumerationFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.PairwiseFactorGroup.html#pgmax.fg.groups.PairwiseFactorGroup).
+# PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) for representing groups of similar factors. The code above makes use of [`EnumFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.EnumFactorGroup.html#pgmax.fg.groups.EnumFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.PairwiseFactorGroup.html#pgmax.fg.groups.PairwiseFactorGroup).
#
# A [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) takes as argument `variables_for_factors`, which is a list of lists of the variables involved in the different factors, and additional arguments specific to each [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) (e.g. `factor_configs` or `log_potential_matrix` here).
#
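Concretely, for the binary factors built above, these arguments look as follows (a worked illustration using the same conventions as the code in this diff):

~~~python
# Unary factor over one binary variable: two valid configurations, (0,) and (1,)
factor_configs = np.arange(2)[:, None]   # array([[0], [1]])
log_potentials = np.array([0.0, bh[0]])  # one log-potential per configuration

# Pairwise factor over two binary variables: a 2 x 2 log-potential matrix whose
# entry [i, j] scores the joint state (first variable = i, second variable = j)
log_potential_matrix = np.array([[0.0, 0.0],
                                 [0.0, W[0, 0]]])
~~~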
-# In this example, since we construct `fg` with variables `hidden_variables` and `visible_variables`, which are both [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)s, we can refer to the `ii`th hidden variable as `hidden_variables[ii]` and the `jj`th visible variable as `visible_variables[jj]`.
+# In this example, since we construct `fg` with variables `hidden_variables` and `visible_variables`, which are both [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVarArray.html#pgmax.fg.groups.NDVarArray)s, we can refer to the `ii`th hidden variable as `hidden_variables[ii]` and the `jj`th visible variable as `visible_variables[jj]`.
#
# An alternative way of creating the above factors is to add them iteratively without building the [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)s as below. This approach is not recommended as it is not computationally efficient.
# ~~~python
-# from pgmax.factors import enumeration as enumeration_factor
+# from pgmax import factor
# import itertools
# from tqdm import tqdm
#
# # Add unary factors
# for ii in range(bh.shape[0]):
-#     factor = enumeration_factor.EnumerationFactor(
+#     unary_factor = factor.EnumFactor(
#         variables=[hidden_variables[ii]],
#         factor_configs=np.arange(2)[:, None],
#         log_potentials=np.array([0, bh[ii]]),
#     )
-#     fg.add_factors(factor)
+#     fg.add_factors(unary_factor)
#
# for jj in range(bv.shape[0]):
-#     factor = enumeration_factor.EnumerationFactor(
+#     unary_factor = factor.EnumFactor(
#         variables=[visible_variables[jj]],
#         factor_configs=np.arange(2)[:, None],
#         log_potentials=np.array([0, bv[jj]]),
#     )
-#     fg.add_factors(factor)
+#     fg.add_factors(unary_factor)
#
# # Add pairwise factors
# factor_configs = np.array(list(itertools.product(np.arange(2), repeat=2)))
# for ii in tqdm(range(bh.shape[0])):
#     for jj in range(bv.shape[0]):
-#         factor = enumeration_factor.EnumerationFactor(
+#         pairwise_factor = factor.EnumFactor(
#             variables=[hidden_variables[ii], visible_variables[jj]],
#             factor_configs=factor_configs,
#             log_potentials=np.array([0, 0, 0, W[ii, jj]]),
#         )
-#         fg.add_factors(factor)
+#         fg.add_factors(pairwise_factor)
# ~~~
#
# Once we have added the factors, we can run max-product LBP and get MAP decoding by
# ~~~python
-# bp = graph.BP(fg.bp_state, temperature=0.0)
+# bp = infer.BP(fg.bp_state, temperature=0.0)
# bp_arrays = bp.run_bp(bp.init(), num_iters=100, damping=0.5)
# beliefs = bp.get_beliefs(bp_arrays)
-# map_states = graph.decode_map_states(beliefs)
+# map_states = infer.decode_map_states(beliefs)
# ~~~
# and run sum-product LBP and get estimated marginals by
# ~~~python
-# bp = graph.BP(fg.bp_state, temperature=1.0)
+# bp = infer.BP(fg.bp_state, temperature=1.0)
# bp_arrays = bp.run_bp(bp.init(), num_iters=100, damping=0.5)
# beliefs = bp.get_beliefs(bp_arrays)
-# marginals = graph.get_marginals(beliefs)
+# marginals = infer.get_marginals(beliefs)
# ~~~
# More generally, PGMax implements LBP with temperature, with `temperature=0.0` and `temperature=1.0` corresponding to the commonly used max/sum-product LBP respectively.
#
# Now we are ready to demonstrate PMP sampling from the RBM. PMP perturbs the model with [Gumbel](https://numpy.org/doc/stable/reference/random/generated/numpy.random.gumbel.html) unary potentials, and draws a sample from the RBM as the MAP decoding from running max-product LBP on the perturbed model.

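The perturbation step itself is only partially visible in this excerpt (the arguments of `bp.init` are cut off below). A sketch of one way to inject the Gumbel unary perturbations, assuming `bp.init` accepts an `evidence_updates` mapping from variable groups to arrays of unary evidence:

~~~python
# Assumption: evidence_updates maps each VarGroup to an array of shape (num_variables, num_states)
bp_arrays = bp.init(
    evidence_updates={
        hidden_variables: np.random.gumbel(size=(bh.shape[0], 2)),
        visible_variables: np.random.gumbel(size=(bv.shape[0], 2)),
    }
)
bp_arrays = bp.run_bp(bp_arrays, num_iters=100, damping=0.5)
beliefs = bp.get_beliefs(bp_arrays)
~~~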
# %%
-bp = graph.BP(fg.bp_state, temperature=0.0)
+bp = infer.BP(fg.bp_state, temperature=0.0)

# %%
bp_arrays = bp.init(
...
# %%
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(
-    graph.decode_map_states(beliefs)[visible_variables].copy().reshape((28, 28)),
+    infer.decode_map_states(beliefs)[visible_variables].copy().reshape((28, 28)),
    cmap="gray",
)
ax.axis("off")

# %% [markdown]
# PGMax adopts a functional interface for implementing LBP: running LBP in PGMax starts with
# ~~~python
-# bp = graph.BP(fg.bp_state, temperature=T)
+# bp = infer.BP(fg.bp_state, temperature=T)
# ~~~
-# where the arguments of the `bp` are several useful functions to run LBP. In particular, `bp.init`, `bp.run_bp`, `bp.get_beliefs` are pure functions with no side-effects. This design choice means that we can easily apply JAX transformations like `jit`/`vmap`/`grad`, etc., to these functions, and additionally allows PGMax to seamlessly interact with other packages in the rapidly growing JAX ecosystem (see [here](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) and [here](https://github.com/n2cholas/awesome-jax)).
+# where `bp` bundles several useful functions for running LBP. In particular, `bp.init`, `bp.run_bp` and `bp.get_beliefs` are pure functions with no side-effects. This design choice means that we can easily apply JAX transformations like `jit`/`vmap`/`grad`, etc., to these functions, and additionally allows PGMax to seamlessly interact with other packages in the rapidly growing JAX ecosystem (see [here](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) and [here](https://github.com/n2cholas/awesome-jax)).
#
# As an example of applying `jax.vmap` to `bp.init`/`bp.run_bp`/`bp.get_beliefs` to process a batch of samples/models in parallel, instead of drawing one sample at a time as above, we can draw a batch of samples in parallel as follows:

...
)(bp_arrays)

beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
-map_states = graph.decode_map_states(beliefs)
+map_states = infer.decode_map_states(beliefs)

# %% [markdown]
# Visualizing the MAP decodings, we see that we have sampled 10 MNIST digits in parallel!
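A sketch of one way to display those decodings, assuming a batch of 10 samples was drawn above and reusing the `decode_map_states(beliefs)[visible_variables]` access pattern from the single-sample plot earlier:

~~~python
# Assumption: the leading dimension of the decoded array is the batch of 10 samples
visible_samples = map_states[visible_variables]  # shape (10, bv.shape[0])
fig, axs = plt.subplots(2, 5, figsize=(20, 8))
for idx, ax in enumerate(axs.flatten()):
    ax.imshow(visible_samples[idx].reshape((28, 28)), cmap="gray")
    ax.axis("off")
~~~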