Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Commit dfc7535

Browse files
Variables refactor (#136)
* Rewrite NDVariableArray * Flake8 * Test + HashableDict * Minor * Variables as tuple + Remove BPOutputs/HashableDict * Start tests + mypy * Tests passing * Variables * Tests + mypy * Some docstrings * Stannis first comments * Remove add_factor * Test * Docstring * Coverage * Coverage * Coverage 100% * Remove factor group names * Remove factor group names * Modify hash + add_factors * Stannis' comments * Flatten / unflatten * Unflatten with nan * Speeding up * max size * Understand timings * Some comments * Comments * Minor * Docstring * Minor changes * Doc * Rename this_hash * Final comments * Minor Co-authored-by: stannis <[email protected]>
1 parent 58fbe95 commit dfc7535

20 files changed

+1074
-1341
lines changed

examples/gmrf.py

+46-33
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@
3838

3939
# %%
4040
# Load saved log potentials
41-
log_potentials = dict(**np.load("example_data/gmrf_log_potentials.npz"))
42-
n_clones = log_potentials.pop("n_clones")
41+
grmf_log_potentials = dict(**np.load("example_data/gmrf_log_potentials.npz"))
42+
n_clones = grmf_log_potentials.pop("n_clones")
4343
p_contour = jax.device_put(np.repeat(data["p_contour"], n_clones))
4444
prototype_targets = jax.device_put(
4545
np.array(
@@ -58,42 +58,54 @@
5858
fg = graph.FactorGraph(variables)
5959

6060
# %%
61-
# Add top-down factors
62-
fg.add_factor_group(
63-
factory=enumeration.PairwiseFactorGroup,
64-
variable_names_for_factors=[
65-
[(ii, jj), (ii + 1, jj)] for ii in range(M - 1) for jj in range(N)
61+
# Create top-down factors
62+
top_down = enumeration.PairwiseFactorGroup(
63+
variables_for_factors=[
64+
[variables[ii, jj], variables[ii + 1, jj]]
65+
for ii in range(M - 1)
66+
for jj in range(N)
6667
],
67-
name="top_down",
6868
)
69-
# Add left-right factors
70-
fg.add_factor_group(
71-
factory=enumeration.PairwiseFactorGroup,
72-
variable_names_for_factors=[
73-
[(ii, jj), (ii, jj + 1)] for ii in range(M) for jj in range(N - 1)
69+
70+
# Create left-right factors
71+
left_right = enumeration.PairwiseFactorGroup(
72+
variables_for_factors=[
73+
[variables[ii, jj], variables[ii, jj + 1]]
74+
for ii in range(M)
75+
for jj in range(N - 1)
7476
],
75-
name="left_right",
7677
)
77-
# Add diagonal factors
78-
fg.add_factor_group(
79-
factory=enumeration.PairwiseFactorGroup,
80-
variable_names_for_factors=[
81-
[(ii, jj), (ii + 1, jj + 1)] for ii in range(M - 1) for jj in range(N - 1)
78+
79+
# Create diagonal factors
80+
diagonal0 = enumeration.PairwiseFactorGroup(
81+
variables_for_factors=[
82+
[variables[ii, jj], variables[ii + 1, jj + 1]]
83+
for ii in range(M - 1)
84+
for jj in range(N - 1)
8285
],
83-
name="diagonal0",
8486
)
85-
fg.add_factor_group(
86-
factory=enumeration.PairwiseFactorGroup,
87-
variable_names_for_factors=[
88-
[(ii, jj), (ii - 1, jj + 1)] for ii in range(1, M) for jj in range(N - 1)
87+
diagonal1 = enumeration.PairwiseFactorGroup(
88+
variables_for_factors=[
89+
[variables[ii, jj], variables[ii - 1, jj + 1]]
90+
for ii in range(1, M)
91+
for jj in range(N - 1)
8992
],
90-
name="diagonal1",
9193
)
9294

95+
# Add factors
96+
fg.add_factors([top_down, left_right, diagonal0, diagonal1])
97+
9398
# %%
9499
bp = graph.BP(fg.bp_state, temperature=1.0)
95100

96101
# %%
102+
log_potentials = {
103+
top_down: grmf_log_potentials["top_down"],
104+
left_right: grmf_log_potentials["left_right"],
105+
diagonal0: grmf_log_potentials["diagonal0"],
106+
diagonal1: grmf_log_potentials["diagonal1"],
107+
}
108+
97109
n_plots = 5
98110
indices = np.random.permutation(noisy_images.shape[0])[:n_plots]
99111
fig, ax = plt.subplots(n_plots, 3, figsize=(30, 10 * n_plots))
@@ -106,14 +118,15 @@
106118
bp.get_beliefs(
107119
bp.run_bp(
108120
bp.init(
109-
evidence_updates={None: evidence},
121+
evidence_updates={variables: evidence},
110122
log_potentials_updates=log_potentials,
111123
),
112124
num_iters=15,
113125
damping=0.0,
114126
)
115127
)
116-
)
128+
)[variables]
129+
117130
pred_image = np.argmax(
118131
np.stack(
119132
[
@@ -153,15 +166,15 @@ def loss(noisy_image, target_image, log_potentials):
153166
bp.get_beliefs(
154167
bp.run_bp(
155168
bp.init(
156-
evidence_updates={None: evidence},
169+
evidence_updates={variables: evidence},
157170
log_potentials_updates=log_potentials,
158171
),
159172
num_iters=15,
160173
damping=0.0,
161174
)
162175
)
163176
)
164-
logp = jnp.mean(jnp.log(jnp.sum(target * marginals, axis=-1)))
177+
logp = jnp.mean(jnp.log(jnp.sum(target * marginals[variables], axis=-1)))
165178
return -logp
166179

167180

@@ -191,10 +204,10 @@ def update(step, batch_noisy_images, batch_target_images, opt_state):
191204
# %%
192205
opt_state = init_fun(
193206
{
194-
"top_down": np.random.randn(num_states, num_states),
195-
"left_right": np.random.randn(num_states, num_states),
196-
"diagonal0": np.random.randn(num_states, num_states),
197-
"diagonal1": np.random.randn(num_states, num_states),
207+
top_down: np.random.randn(num_states, num_states),
208+
left_right: np.random.randn(num_states, num_states),
209+
diagonal0: np.random.randn(num_states, num_states),
210+
diagonal1: np.random.randn(num_states, num_states),
198211
}
199212
)
200213

examples/ising_model.py

+21-19
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,21 @@
2929

3030
# %%
3131
variables = vgroup.NDVariableArray(num_states=2, shape=(50, 50))
32-
fg = graph.FactorGraph(variables=variables)
33-
variable_names_for_factors = []
32+
fg = graph.FactorGraph(variable_groups=variables)
33+
34+
variables_for_factors = []
3435
for ii in range(50):
3536
for jj in range(50):
3637
kk = (ii + 1) % 50
3738
ll = (jj + 1) % 50
38-
variable_names_for_factors.append([(ii, jj), (kk, jj)])
39-
variable_names_for_factors.append([(ii, jj), (ii, ll)])
39+
variables_for_factors.append([variables[ii, jj], variables[kk, jj]])
40+
variables_for_factors.append([variables[ii, jj], variables[ii, ll]])
4041

41-
fg.add_factor_group(
42-
factory=enumeration.PairwiseFactorGroup,
43-
variable_names_for_factors=variable_names_for_factors,
42+
factor_group = enumeration.PairwiseFactorGroup(
43+
variables_for_factors=variables_for_factors,
4444
log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]),
45-
name="factors",
4645
)
46+
fg.add_factors(factor_group)
4747

4848
# %% [markdown]
4949
# ### Run inference and visualize results
@@ -53,12 +53,13 @@
5353

5454
# %%
5555
bp_arrays = bp.init(
56-
evidence_updates={None: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
56+
evidence_updates={variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
5757
)
5858
bp_arrays = bp.run_bp(bp_arrays, num_iters=3000)
59+
beliefs = bp.get_beliefs(bp_arrays)
5960

6061
# %%
61-
img = graph.decode_map_states(bp.get_beliefs(bp_arrays))
62+
img = graph.decode_map_states(beliefs)[variables]
6263
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
6364
ax.imshow(img)
6465

@@ -73,19 +74,20 @@ def loss(log_potentials_updates, evidence_updates):
7374
)
7475
bp_arrays = bp.run_bp(bp_arrays, num_iters=3000)
7576
beliefs = bp.get_beliefs(bp_arrays)
76-
loss = -jnp.sum(beliefs)
77+
loss = -jnp.sum(beliefs[variables])
7778
return loss
7879

7980

80-
batch_loss = jax.jit(jax.vmap(loss, in_axes=(None, {None: 0}), out_axes=0))
81+
batch_loss = jax.jit(jax.vmap(loss, in_axes=(None, {variables: 0}), out_axes=0))
8182
log_potentials_grads = jax.jit(jax.grad(loss, argnums=0))
8283

8384
# %%
84-
batch_loss(None, {None: jax.device_put(np.random.gumbel(size=(10, 50, 50, 2)))})
85+
batch_loss(None, {variables: jax.device_put(np.random.gumbel(size=(10, 50, 50, 2)))})
8586

8687
# %%
8788
grads = log_potentials_grads(
88-
{"factors": jnp.eye(2)}, {None: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
89+
{factor_group: jnp.eye(2)},
90+
{variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))},
8991
)
9092

9193
# %% [markdown]
@@ -95,15 +97,15 @@ def loss(log_potentials_updates, evidence_updates):
9597
bp_state = bp.to_bp_state(bp_arrays)
9698

9799
# Query evidence for variable (0, 0)
98-
bp_state.evidence[0, 0]
100+
bp_state.evidence[variables[0, 0]]
99101

100102
# %%
101103
# Set evidence for variable (0, 0)
102-
bp_state.evidence[0, 0] = np.array([1.0, 1.0])
103-
bp_state.evidence[0, 0]
104+
bp_state.evidence[variables[0, 0]] = np.array([1.0, 1.0])
105+
bp_state.evidence[variables[0, 0]]
104106

105107
# %%
106108
# Set evidence for all variables using an array
107109
evidence = np.random.randn(50, 50, 2)
108-
bp_state.evidence[None] = evidence
109-
bp_state.evidence[10, 10] == evidence[10, 10]
110+
bp_state.evidence[variables] = evidence
111+
np.allclose(bp_state.evidence[variables[10, 10]], evidence[10, 10])

examples/pmp_binary_deconvolution.py

+30-41
Original file line numberDiff line numberDiff line change
@@ -134,15 +134,15 @@ def plot_images(images, display=True, nr=None):
134134
X = vgroup.NDVariableArray(num_states=2, shape=X_gt.shape)
135135

136136
# %% [markdown]
137-
# For computation efficiency, we add large FactorGroups via `fg.add_factor_group` instead of adding individual Factors
137+
# For computation efficiency, we construct large FactorGroups instead of individual factors
138138

139139
# %%
140140
# Factor graph
141-
fg = graph.FactorGraph(variables=dict(S=S, W=W, SW=SW, X=X))
141+
fg = graph.FactorGraph(variable_groups=[S, W, SW, X])
142142

143143
# Define the ANDFactors
144-
variable_names_for_ANDFactors = []
145-
variable_names_for_ORFactors_dict = defaultdict(list)
144+
variables_for_ANDFactors = []
145+
variables_for_ORFactors_dict = defaultdict(list)
146146
for idx_img in tqdm(range(n_images)):
147147
for idx_chan in range(n_chan):
148148
for idx_s_height in range(s_height):
@@ -152,52 +152,39 @@ def plot_images(images, display=True, nr=None):
152152
for idx_feat_width in range(feat_width):
153153
idx_img_height = idx_feat_height + idx_s_height
154154
idx_img_width = idx_feat_width + idx_s_width
155-
SW_var = (
156-
"SW",
155+
SW_var = SW[
157156
idx_img,
158157
idx_chan,
159158
idx_img_height,
160159
idx_img_width,
161160
idx_feat,
162161
idx_feat_height,
163162
idx_feat_width,
164-
)
165-
166-
variable_names_for_ANDFactor = [
167-
("S", idx_img, idx_feat, idx_s_height, idx_s_width),
168-
(
169-
"W",
170-
idx_chan,
171-
idx_feat,
172-
idx_feat_height,
173-
idx_feat_width,
174-
),
163+
]
164+
165+
variables_for_ANDFactor = [
166+
S[idx_img, idx_feat, idx_s_height, idx_s_width],
167+
W[idx_chan, idx_feat, idx_feat_height, idx_feat_width],
175168
SW_var,
176169
]
177-
variable_names_for_ANDFactors.append(
178-
variable_names_for_ANDFactor
179-
)
170+
variables_for_ANDFactors.append(variables_for_ANDFactor)
180171

181-
X_var = (idx_img, idx_chan, idx_img_height, idx_img_width)
182-
variable_names_for_ORFactors_dict[X_var].append(SW_var)
172+
X_var = X[idx_img, idx_chan, idx_img_height, idx_img_width]
173+
variables_for_ORFactors_dict[X_var].append(SW_var)
183174

184175
# Add ANDFactorGroup, which is computationally efficient
185-
fg.add_factor_group(
186-
factory=logical.ANDFactorGroup,
187-
variable_names_for_factors=variable_names_for_ANDFactors,
188-
)
176+
AND_factor_group = logical.ANDFactorGroup(variables_for_ANDFactors)
177+
fg.add_factors(AND_factor_group)
189178

190179
# Define the ORFactors
191-
variable_names_for_ORFactors = [
192-
list(tuple(variable_names_for_ORFactors_dict[X_var]) + (("X",) + X_var,))
193-
for X_var in variable_names_for_ORFactors_dict
180+
variables_for_ORFactors = [
181+
list(tuple(variables_for_ORFactors_dict[X_var]) + (X_var,))
182+
for X_var in variables_for_ORFactors_dict
194183
]
195184

196185
# Add ORFactorGroup, which is computationally efficient
197-
fg.add_factor_group(
198-
factory=logical.ORFactorGroup,
199-
variable_names_for_factors=variable_names_for_ORFactors,
200-
)
186+
OR_factor_group = logical.ORFactorGroup(variables_for_ORFactors)
187+
fg.add_factors(OR_factor_group)
201188

202189
for factor_type, factor_groups in fg.factor_groups.items():
203190
if len(factor_groups) > 0:
@@ -222,7 +209,7 @@ def plot_images(images, display=True, nr=None):
222209

223210
# %%
224211
pW = 0.25
225-
pS = 1e-100
212+
pS = 1e-75
226213
pX = 1e-100
227214

228215
# Sparsity inducing priors for W and S
@@ -237,25 +224,27 @@ def plot_images(images, display=True, nr=None):
237224
uX[..., 0] = (2 * X_gt - 1) * logit(pX)
238225

239226
# %% [markdown]
240-
# We draw a batch of samples from the posterior in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap`
227+
# We draw a batch of samples from the posterior in parallel by transforming `bp.init`/`bp.run_bp`/`bp.get_beliefs` with `jax.vmap`
241228

242229
# %%
243-
np.random.seed(seed=40)
230+
np.random.seed(seed=0)
244231
n_samples = 4
245232

246233
bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)(
247234
evidence_updates={
248-
"S": uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape),
249-
"W": uW[None] + np.random.gumbel(size=(n_samples,) + uW.shape),
250-
"SW": np.zeros(shape=(n_samples,) + SW.shape),
251-
"X": uX[None] + np.zeros(shape=(n_samples,) + uX.shape),
235+
S: uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape),
236+
W: uW[None] + np.random.gumbel(size=(n_samples,) + uW.shape),
237+
SW: np.zeros(shape=(n_samples,) + SW.shape),
238+
X: uX[None] + np.zeros(shape=(n_samples,) + uX.shape),
252239
},
253240
)
241+
254242
bp_arrays = jax.vmap(
255243
functools.partial(bp.run_bp, num_iters=100, damping=0.5),
256244
in_axes=0,
257245
out_axes=0,
258246
)(bp_arrays)
247+
259248
beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
260249
map_states = graph.decode_map_states(beliefs)
261250

@@ -265,4 +254,4 @@ def plot_images(images, display=True, nr=None):
265254
# Because we have used one extra feature for inference, each posterior sample recovers the 4 basic features used to generate the images, and includes an extra symbol.
266255

267256
# %%
268-
_ = plot_images(map_states["W"].reshape(-1, feat_height, feat_width), nr=n_samples)
257+
_ = plot_images(map_states[W].reshape(-1, feat_height, feat_width), nr=n_samples)

0 commit comments

Comments
 (0)