Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Commit a6be2a6

Browse files
Includes version of Heretic Model in PGMax! (#47)
* includes weights from @lazarox trained heretic model * includes support for non-uniform potentials * includes fully-complete heretic_example * updates decode_map_states * function now outputs mapping from keys (instead of Variables) to MAP states * sanity_check and heretic examples updates to call this new function
1 parent b9feffc commit a6be2a6

File tree

5 files changed

+263
-12
lines changed

5 files changed

+263
-12
lines changed
Binary file not shown.

examples/heretic_example.py

+245
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
# ---
2+
# jupyter:
3+
# jupytext:
4+
# formats: ipynb,py:percent
5+
# text_representation:
6+
# extension: .py
7+
# format_name: percent
8+
# format_version: '1.3'
9+
# jupytext_version: 1.11.4
10+
# kernelspec:
11+
# display_name: 'Python 3.8.5 64-bit (''pgmax-JcKb81GE-py3.8'': poetry)'
12+
# name: python3
13+
# ---
14+
15+
# %%
16+
# %matplotlib inline
17+
# fmt: off
18+
19+
# Standard Package Imports
20+
import matplotlib.pyplot as plt # isort:skip
21+
import numpy as np # isort:skip
22+
import jax # isort:skip
23+
import jax.numpy as jnp # isort:skip
24+
from typing import Any, Tuple, List # isort:skip
25+
from timeit import default_timer as timer # isort:skip
26+
27+
# Custom Imports
28+
import pgmax.fg.groups as groups # isort:skip
29+
import pgmax.fg.graph as graph # isort:skip
30+
31+
# fmt: on
32+
33+
# %% [markdown]
34+
# # Setup Variables
35+
36+
# %%
37+
# Define some global constants
38+
im_size = (30, 30)
39+
prng_key = jax.random.PRNGKey(42)
40+
41+
# Instantiate all the Variables in the factor graph via VariableGroups
42+
pixel_vars = groups.NDVariableArray(3, im_size)
43+
hidden_vars = groups.NDVariableArray(
44+
17, (im_size[0] - 2, im_size[1] - 2)
45+
) # Each hidden var is connected to a 3x3 patch of pixel vars
46+
composite_vargroup = groups.CompositeVariableGroup((pixel_vars, hidden_vars))
47+
48+
# %% [markdown]
49+
# # Load Trained Weights And Setup Evidence
50+
51+
# %%
52+
# Load weights and create evidence (taken directly from @lazarox's code)
53+
crbm_weights = np.load("example_data/crbm_mnist_weights_surfaces_pmap002.npz")
54+
W_orig, bX, bH = crbm_weights["W"], crbm_weights["bX"], crbm_weights["bH"]
55+
n_samples = 1
56+
T = 1
57+
58+
im_height, im_width = im_size
59+
n_cat_X, n_cat_H, f_s = W_orig.shape[:3]
60+
W = W_orig.reshape(1, n_cat_X, n_cat_H, f_s, f_s, 1, 1)
61+
bXn = jnp.zeros((n_samples, n_cat_X, 1, 1, 1, im_height, im_width))
62+
border = jnp.zeros((1, n_cat_X, 1, 1, 1) + im_size)
63+
border = border.at[:, 1:, :, :, :, :1, :].set(-10)
64+
border = border.at[:, 1:, :, :, :, -1:, :].set(-10)
65+
border = border.at[:, 1:, :, :, :, :, :1].set(-10)
66+
border = border.at[:, 1:, :, :, :, :, -1:].set(-10)
67+
bXn = bXn + border
68+
rng, rng_input = jax.random.split(prng_key)
69+
rnX = jax.random.gumbel(
70+
rng_input, shape=(n_samples, n_cat_X, 1, 1, 1, im_height, im_width)
71+
)
72+
bXn = bXn + bX[None, :, :, :, :, None, None] + T * rnX
73+
rng, rng_input = jax.random.split(rng)
74+
rnH = jax.random.gumbel(
75+
rng_input,
76+
shape=(n_samples, 1, n_cat_H, 1, 1, im_height - f_s + 1, im_width - f_s + 1),
77+
)
78+
bHn = bH[None, :, :, :, :, None, None] + T * rnH
79+
80+
bXn_evidence = bXn.reshape((3, 30, 30))
81+
bXn_evidence = bXn_evidence.swapaxes(0, 1)
82+
bXn_evidence = bXn_evidence.swapaxes(1, 2)
83+
bHn_evidence = bHn.reshape((17, 28, 28))
84+
bHn_evidence = bHn_evidence.swapaxes(0, 1)
85+
bHn_evidence = bHn_evidence.swapaxes(1, 2)
86+
87+
88+
# %% [markdown]
89+
# # Create FactorGraph and Assign Evidence
90+
91+
# %%
92+
# Create the factor graph
93+
fg = graph.FactorGraph((pixel_vars, hidden_vars))
94+
95+
# Assign evidence to pixel vars
96+
fg.set_evidence(0, np.array(bXn_evidence))
97+
fg.set_evidence(1, np.array(bHn_evidence))
98+
99+
100+
# %% [markdown]
101+
# # Add all Factors to graph via constructing FactorGroups
102+
103+
# %%
104+
def binary_connected_variables(
105+
num_hidden_rows, num_hidden_cols, kernel_row, kernel_col
106+
):
107+
ret_list: List[List[Tuple[Any, ...]]] = []
108+
for h_row in range(num_hidden_rows):
109+
for h_col in range(num_hidden_cols):
110+
ret_list.append(
111+
[
112+
(1, h_row, h_col),
113+
(0, h_row + kernel_row, h_col + kernel_col),
114+
]
115+
)
116+
return ret_list
117+
118+
119+
W_pot = W_orig.swapaxes(0, 1)
120+
for k_row in range(3):
121+
for k_col in range(3):
122+
fg.add_factors(
123+
factor_factory=groups.PairwiseFactorGroup,
124+
connected_var_keys=binary_connected_variables(28, 28, k_row, k_col),
125+
log_potential_matrix=W_pot[:, :, k_row, k_col],
126+
)
127+
128+
# %% [markdown]
129+
# # Construct Initial Messages
130+
131+
# %%
132+
133+
134+
def custom_flatten_ordering(Mdown, Mup):
135+
flat_idx = 0
136+
flat_Mdown = Mdown.flatten()
137+
flat_Mup = Mup.flatten()
138+
flattened_arr = np.zeros(
139+
(flat_Mdown.shape[0] + flat_Mup.shape[0]),
140+
)
141+
for kernel_row in range(Mdown.shape[1]):
142+
for kernel_col in range(Mdown.shape[2]):
143+
for row in range(Mdown.shape[3]):
144+
for col in range(Mdown.shape[4]):
145+
flattened_arr[flat_idx : flat_idx + Mup.shape[0]] = Mup[
146+
:, kernel_row, kernel_col, row, col
147+
]
148+
flat_idx += Mup.shape[0]
149+
flattened_arr[flat_idx : flat_idx + Mdown.shape[0]] = Mdown[
150+
:, kernel_row, kernel_col, row, col
151+
]
152+
flat_idx += Mdown.shape[0]
153+
return flattened_arr
154+
155+
156+
# NOTE: This block only works because it exploits knowledge about the order in which the flat message array is constructed within PGMax.
157+
# Normal users won't have this...
158+
159+
# Create initial messages using bXn and bHn messages from
160+
# features to pixels (taken directly from @lazarox's code)
161+
rng, rng_input = jax.random.split(rng)
162+
Mdown = jnp.zeros(
163+
(n_samples, n_cat_X, 1, f_s, f_s, im_height - f_s + 1, im_width - f_s + 1)
164+
)
165+
Mup = jnp.zeros(
166+
(n_samples, 1, n_cat_H, f_s, f_s, im_height - f_s + 1, im_width - f_s + 1)
167+
)
168+
Mdown = Mdown - bXn[:, :, :, :, :, 1:-1, 1:-1] / f_s ** 2
169+
Mup = Mup - bHn / f_s ** 2
170+
171+
# init_weights = np.load("init_weights_mnist_surfaces_pmap002.npz")
172+
# Mdown, Mup = init_weights["Mdown"], init_weights["Mup"]
173+
# reshaped_Mdown = Mdown.reshape(3, 3, 3, 30, 30)
174+
# reshaped_Mdown = reshaped_Mdown[:,:,:,1:-1, 1:-1]
175+
reshaped_Mdown = Mdown.reshape(3, 3, 3, 28, 28)
176+
reshaped_Mup = Mup.reshape(17, 3, 3, 28, 28)
177+
178+
init_msgs = jax.device_put(
179+
custom_flatten_ordering(np.array(reshaped_Mdown), np.array(reshaped_Mup))
180+
)
181+
182+
# %% [markdown]
183+
# # Run Belief Propagation and Retrieve MAP Estimate
184+
185+
# %%
186+
# Run BP
187+
bp_start_time = timer()
188+
final_msgs = fg.run_bp(
189+
500,
190+
0.5,
191+
init_msgs=init_msgs,
192+
)
193+
bp_end_time = timer()
194+
print(f"time taken for bp {bp_end_time - bp_start_time}")
195+
196+
# Run inference and convert result to human-readable data structure
197+
data_writeback_start_time = timer()
198+
map_message_dict = fg.decode_map_states(
199+
final_msgs,
200+
)
201+
data_writeback_end_time = timer()
202+
print(
203+
f"time taken for data conversion of inference result {data_writeback_end_time - data_writeback_start_time}"
204+
)
205+
206+
207+
# %% [markdown]
208+
# # Plot Results
209+
210+
# %%
211+
# Viz function from @lazarox's code
212+
def plot_images(images):
213+
n_images, H, W = images.shape
214+
images = images - images.min()
215+
images /= images.max() + 1e-10
216+
217+
nr = nc = np.ceil(np.sqrt(n_images)).astype(int)
218+
big_image = np.ones(((H + 1) * nr + 1, (W + 1) * nc + 1, 3))
219+
big_image[..., :2] = 0
220+
im = 0
221+
for r in range(nr):
222+
for c in range(nc):
223+
if im < n_images:
224+
big_image[
225+
(H + 1) * r + 1 : (H + 1) * r + 1 + H,
226+
(W + 1) * c + 1 : (W + 1) * c + 1 + W,
227+
:,
228+
] = images[im, :, :, None]
229+
im += 1
230+
231+
plt.figure(figsize=(10, 10))
232+
plt.imshow(big_image, interpolation="none")
233+
234+
235+
# %%
236+
img_arr = np.zeros((1, im_size[0], im_size[1]))
237+
238+
for row in range(im_size[0]):
239+
for col in range(im_size[1]):
240+
img_val = float(map_message_dict[0, row, col])
241+
if img_val == 2.0:
242+
img_val = 0.4
243+
img_arr[0, row, col] = img_val * 1.0
244+
245+
plot_images(img_arr)

examples/sanity_check_example.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -365,20 +365,18 @@ def create_valid_suppression_config_arr(suppression_diameter):
365365
for row in range(M):
366366
for col in range(N):
367367
try:
368-
bp_values[i, row, col] = map_message_dict[
369-
composite_grid_group["grid_vars", i, row, col]
370-
]
368+
bp_values[i, row, col] = map_message_dict["grid_vars", i, row, col]
371369
bu_evidence[i, row, col, :] = grid_evidence_arr[i, row, col]
372-
except ValueError:
370+
except KeyError:
373371
try:
374372
bp_values[i, row, col] = map_message_dict[
375-
composite_grid_group["additional_vars", i, row, col]
373+
"additional_vars", i, row, col
376374
]
377375
bu_evidence[i, row, col, :] = additional_vars_evidence_dict[
378376
(i, row, col)
379377
]
380378

381-
except ValueError:
379+
except KeyError:
382380
pass
383381

384382

pgmax/fg/graph.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def message_passing_step(msgs, _):
305305

306306
return msgs_after_bp
307307

308-
def decode_map_states(self, msgs: jnp.ndarray) -> Dict[nodes.Variable, int]:
308+
def decode_map_states(self, msgs: jnp.ndarray) -> Dict[Tuple[Any, ...], int]:
309309
"""Function to computes the output of MAP inference on input messages.
310310
311311
The final states are computed based on evidence obtained from the self.get_evidence
@@ -316,16 +316,17 @@ def decode_map_states(self, msgs: jnp.ndarray) -> Dict[nodes.Variable, int]:
316316
upon
317317
318318
Returns:
319-
a dictionary mapping variables to their MAP state
319+
a dictionary mapping each variable key to the MAP states of the corresponding variable
320320
"""
321321
var_states_for_edges = jax.device_put(self.wiring.var_states_for_edges)
322322
evidence = jax.device_put(self.evidence)
323323
final_var_states = evidence.at[var_states_for_edges].add(msgs)
324-
var_to_map_dict = {}
324+
var_key_to_map_dict: Dict[Tuple[Any, ...], int] = {}
325325
final_var_states_np = np.array(final_var_states)
326-
for var in self._composite_variable_group.variables:
326+
for var_key in self._composite_variable_group.keys:
327+
var = self._composite_variable_group[var_key]
327328
start_index = self._vars_to_starts[var]
328-
var_to_map_dict[var] = np.argmax(
329+
var_key_to_map_dict[var_key] = np.argmax(
329330
final_var_states_np[start_index : start_index + var.num_states]
330331
)
331-
return var_to_map_dict
332+
return var_key_to_map_dict

pgmax/fg/groups.py

+7
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,13 @@ def get_vars_to_evidence(
228228

229229
@cached_property
230230
def container_keys(self) -> Tuple:
231+
"""Function to get keys referring to the variable groups within this
232+
CompositeVariableGroup.
233+
234+
Returns:
235+
a tuple of the keys referring to the variable groups within this
236+
CompositeVariableGroup.
237+
"""
231238
if isinstance(self.variable_group_container, Mapping):
232239
container_keys = tuple(self.variable_group_container.keys())
233240
else:

0 commit comments

Comments
 (0)