diff --git a/examples/gmrf.py b/examples/gmrf.py
index 5228d0d..8b330c4 100644
--- a/examples/gmrf.py
+++ b/examples/gmrf.py
@@ -155,6 +155,7 @@
#
# The following training loop requires a GPU with at least 11 GB of memory.
+
# %%
@jax.jit
def loss(noisy_image, target_image, log_potentials):
diff --git a/examples/ising_model.py b/examples/ising_model.py
index 1befd27..5a2826b 100644
--- a/examples/ising_model.py
+++ b/examples/ising_model.py
@@ -65,6 +65,7 @@
# %% [markdown]
# ### Gradients and batching
+
# %%
def loss(log_potentials_updates, evidence_updates):
bp_arrays = bp.init(
diff --git a/examples/rcn.py b/examples/rcn.py
index df5a30d..24a7eaf 100644
--- a/examples/rcn.py
+++ b/examples/rcn.py
@@ -225,6 +225,7 @@ def fetch_mnist_dataset(test_size: int, seed: int = 5) -> tuple[np.ndarray, np.n
# %% [markdown]
# ### 4.2.1 Pre-compute the valid configs for different perturb radius.
+
# %%
def valid_configs(r: int, hps: int, vps: int) -> np.ndarray:
"""Returns the valid configurations for a factor given the perturb radius.
@@ -294,6 +295,7 @@ def valid_configs(r: int, hps: int, vps: int) -> np.ndarray:
# %% [markdown]
# ## 5.1 Helper functions to initialize the evidence for a given image
+
# %%
def get_bu_msg(img: np.ndarray) -> np.ndarray:
"""Computes the bottom-up messages given a test image.
@@ -365,6 +367,7 @@ def get_bu_msg(img: np.ndarray) -> np.ndarray:
# %% [markdown]
# ## 5.2 Run MAP inference on all test images
+
# %%
def get_evidence(bu_msg: np.ndarray, frc: np.ndarray) -> np.ndarray:
"""Returns the evidence (shape (n_frcs, M)).
diff --git a/pgmax/factor/enum.py b/pgmax/factor/enum.py
index 856965e..bd27075 100644
--- a/pgmax/factor/enum.py
+++ b/pgmax/factor/enum.py
@@ -256,7 +256,6 @@ def pass_enum_fac_to_var_messages(
num_val_configs: int,
temperature: float,
) -> jnp.ndarray:
-
"""Passes messages from EnumFactors to Variables.
The update is performed in two steps. First, a "summary" array is generated that has an entry for every valid
diff --git a/pgmax/factor/logical.py b/pgmax/factor/logical.py
index b9d9fcd..97e19a9 100644
--- a/pgmax/factor/logical.py
+++ b/pgmax/factor/logical.py
@@ -312,7 +312,6 @@ def pass_logical_fac_to_var_messages(
temperature: float,
log_potentials: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
-
"""Passes messages from LogicalFactors to Variables.
Args: