Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Commit 4231b12

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 9eaf2be commit 4231b12

File tree

5 files changed

+5
-2
lines changed

5 files changed

+5
-2
lines changed

examples/gmrf.py

+1
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@
155155
#
156156
# The following training loop requires a GPU with at least 11 GB of memory.
157157

158+
158159
# %%
159160
@jax.jit
160161
def loss(noisy_image, target_image, log_potentials):

examples/ising_model.py

+1
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
# %% [markdown]
6666
# ### Gradients and batching
6767

68+
6869
# %%
6970
def loss(log_potentials_updates, evidence_updates):
7071
bp_arrays = bp.init(

examples/rcn.py

+3
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def fetch_mnist_dataset(test_size: int, seed: int = 5) -> tuple[np.ndarray, np.n
225225
# %% [markdown]
226226
# ### 4.2.1 Pre-compute the valid configs for different perturb radius.
227227

228+
228229
# %%
229230
def valid_configs(r: int, hps: int, vps: int) -> np.ndarray:
230231
"""Returns the valid configurations for a factor given the perturb radius.
@@ -294,6 +295,7 @@ def valid_configs(r: int, hps: int, vps: int) -> np.ndarray:
294295
# %% [markdown]
295296
# ## 5.1 Helper functions to initialize the evidence for a given image
296297

298+
297299
# %%
298300
def get_bu_msg(img: np.ndarray) -> np.ndarray:
299301
"""Computes the bottom-up messages given a test image.
@@ -365,6 +367,7 @@ def get_bu_msg(img: np.ndarray) -> np.ndarray:
365367
# %% [markdown]
366368
# ## 5.2 Run MAP inference on all test images
367369

370+
368371
# %%
369372
def get_evidence(bu_msg: np.ndarray, frc: np.ndarray) -> np.ndarray:
370373
"""Returns the evidence (shape (n_frcs, M)).

pgmax/factor/enum.py

-1
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,6 @@ def pass_enum_fac_to_var_messages(
256256
num_val_configs: int,
257257
temperature: float,
258258
) -> jnp.ndarray:
259-
260259
"""Passes messages from EnumFactors to Variables.
261260
262261
The update is performed in two steps. First, a "summary" array is generated that has an entry for every valid

pgmax/factor/logical.py

-1
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,6 @@ def pass_logical_fac_to_var_messages(
312312
temperature: float,
313313
log_potentials: Optional[jnp.ndarray] = None,
314314
) -> jnp.ndarray:
315-
316315
"""Passes messages from LogicalFactors to Variables.
317316
318317
Args:

0 commit comments

Comments
 (0)