Skip to content
This repository has been archived by the owner on Dec 5, 2024. It is now read-only.
Permalink

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories. View the default comparison for this range or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: vicariousinc/PGMax
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: 2f7219456afb6871792fc0334f655fc1d6ebf5e2
Choose a base ref
..
head repository: vicariousinc/PGMax
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: 9eaf2befe501197547240559c5ebcd71ad6618cc
Choose a head ref
Showing with 3 additions and 6 deletions.
  1. +1 −1 .pre-commit-config.yaml
  2. +0 −1 examples/gmrf.py
  3. +0 −1 examples/ising_model.py
  4. +0 −3 examples/rcn.py
  5. +1 −0 pgmax/factor/enum.py
  6. +1 −0 pgmax/factor/logical.py
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -28,7 +28,7 @@ repos:
verbose: true

- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v1.1.1' # Use the sha / tag you want to point at
rev: 'v1.4.1' # Use the sha / tag you want to point at
hooks:
- id: mypy
additional_dependencies: [tokenize-rt==3.2.0]
1 change: 0 additions & 1 deletion examples/gmrf.py
Original file line number Diff line number Diff line change
@@ -155,7 +155,6 @@
#
# The following training loop requires a GPU with at least 11 GB of memory.


# %%
@jax.jit
def loss(noisy_image, target_image, log_potentials):
1 change: 0 additions & 1 deletion examples/ising_model.py
Original file line number Diff line number Diff line change
@@ -65,7 +65,6 @@
# %% [markdown]
# ### Gradients and batching


# %%
def loss(log_potentials_updates, evidence_updates):
bp_arrays = bp.init(
3 changes: 0 additions & 3 deletions examples/rcn.py
Original file line number Diff line number Diff line change
@@ -225,7 +225,6 @@ def fetch_mnist_dataset(test_size: int, seed: int = 5) -> tuple[np.ndarray, np.n
# %% [markdown]
# ### 4.2.1 Pre-compute the valid configs for different perturb radius.


# %%
def valid_configs(r: int, hps: int, vps: int) -> np.ndarray:
"""Returns the valid configurations for a factor given the perturb radius.
@@ -295,7 +294,6 @@ def valid_configs(r: int, hps: int, vps: int) -> np.ndarray:
# %% [markdown]
# ## 5.1 Helper functions to initialize the evidence for a given image


# %%
def get_bu_msg(img: np.ndarray) -> np.ndarray:
"""Computes the bottom-up messages given a test image.
@@ -367,7 +365,6 @@ def get_bu_msg(img: np.ndarray) -> np.ndarray:
# %% [markdown]
# ## 5.2 Run MAP inference on all test images


# %%
def get_evidence(bu_msg: np.ndarray, frc: np.ndarray) -> np.ndarray:
"""Returns the evidence (shape (n_frcs, M)).
1 change: 1 addition & 0 deletions pgmax/factor/enum.py
Original file line number Diff line number Diff line change
@@ -256,6 +256,7 @@ def pass_enum_fac_to_var_messages(
num_val_configs: int,
temperature: float,
) -> jnp.ndarray:

"""Passes messages from EnumFactors to Variables.
The update is performed in two steps. First, a "summary" array is generated that has an entry for every valid
1 change: 1 addition & 0 deletions pgmax/factor/logical.py
Original file line number Diff line number Diff line change
@@ -312,6 +312,7 @@ def pass_logical_fac_to_var_messages(
temperature: float,
log_potentials: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:

"""Passes messages from LogicalFactors to Variables.
Args: