Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Commit 107dd8e

Browse files
Changes to Docs (#142)
* Docs * Comments * Typo
1 parent c62b624 commit 107dd8e

File tree

10 files changed

+56
-57
lines changed

10 files changed

+56
-57
lines changed

examples/rbm.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@
2929
from pgmax import fgraph, fgroup, infer, vgroup
3030

3131
# %% [markdown]
32-
# The [`pgmax.fgraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgraph.html#module-pgmax.fgraph) module contains classes for specifying factor graphs, the [`pgmax.fgroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.html#module-pgmax.vgroup) module contains classes for specifying groups of variables, the [`pgmax.vgroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.html#module-pgmax.fgroup) module contains classes for specifying groups of factors and the [`pgmax.infer`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.infer.html#module-pgmax.infer) module containing core functions to perform LBP.
32+
# The [`pgmax.fgraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgraph.html#module-pgmax.fgraph) module contains classes for specifying factor graphs, the [`pgmax.vgroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.vgroup.html#module-pgmax.vgroup) module contains classes for specifying groups of variables, the [`pgmax.fgroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.html#module-pgmax.fgroup) module contains classes for specifying groups of factors and the [`pgmax.infer`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.infer.html#module-pgmax.infer) module contains functions to perform LBP.
3333
#
34-
# We next load the RBM trained in Sec. 5.5 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) on MNIST digits.
34+
# We next load the RBM trained in Sec. 5.5 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) which has been trained on MNIST digits.
3535

3636
# %%
3737
# Load parameters
@@ -50,9 +50,9 @@
5050
fg = fgraph.FactorGraph(variable_groups=[hidden_variables, visible_variables])
5151

5252
# %% [markdown]
53-
# [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVarArray.html#pgmax.fg.groups.NDVarArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VarGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VarGroup.html#pgmax.fg.groups.VarGroup) (e.g. an [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVarArray.html#pgmax.fg.groups.NDVarArray)), or a list of [`VarGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VarGroup.html#pgmax.fg.groups.VarGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
53+
# [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.vgroup.varray.NDVarArray.html#pgmax.vgroup.varray.NDVarArray) is a convenient class for specifying a group of variables living on a multidimensional grid with possibly different number of states: this class shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgraph.fgraph.FactorGraph.html#pgmax.fgraph.fgraph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VarGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.vgroup.vgroup.VarGroup.html#pgmax.vgroup.vgroup.VarGroup) (e.g. an [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.vgroup.varray.NDVarArray.html#pgmax.vgroup.varray.NDVarArray)), or a list of [`VarGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.vgroup.vgroup.VarGroup.html#pgmax.vgroup.vgroup.VarGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
5454
#
55-
# After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph). We can add the unary and pairwise factors by grouping them using [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)
55+
# After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgraph.fgraph.FactorGraph.html#pgmax.fgraph.fgraph.FactorGraph). We efficiently add the unary and pairwise factors by grouping them using [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.fgroup.FactorGroup.html#pgmax.fgroup.fgroup.FactorGroup)
5656

5757
# %%
5858
# Create unary factors
@@ -86,13 +86,14 @@
8686

8787

8888
# %% [markdown]
89-
# PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) for representing groups of similar factors. The code above makes use of [`EnumFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.EnumFactorGroup.html#pgmax.fg.groups.EnumFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.PairwiseFactorGroup.html#pgmax.fg.groups.PairwiseFactorGroup).
89+
# PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.fgroup.FactorGroup.html#pgmax.fgroup.fgroup.FactorGroup) for representing groups of similar factors. The code above makes use of [`EnumFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.enum.EnumFactorGroup.html#pgmax.fgroup.enum.EnumFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.enum.PairwiseFactorGroup.html#pgmax.fgroup.enum.PairwiseFactorGroup).
9090
#
91-
# A [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) takes as argument `variables_for_factors` which is a list of lists of the variables involved in the different factors, and additional arguments specific to each [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) (e.g. `factor_configs` or `log_potential_matrix` here).
91+
# A [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.fgroup.FactorGroup.html#pgmax.fgroup.fgroup.FactorGroup) takes as argument `variables_for_factors` which is a list of lists of the variables involved in the different factors, and additional arguments specific to each [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.fgroup.FactorGroup.html#pgmax.fgroup.fgroup.FactorGroup) (e.g. `factor_configs` and `log_potential_matrix` here).
9292
#
93-
# In this example, since we construct `fg` with variables `hidden_variables` and `visible_variables`, which are both [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVarArray.html#pgmax.fg.groups.NDVarArray)s, we can refer to the `ii`th hidden variable as `hidden_variables[ii]` and the `jj`th visible variable as `visible_variables[jj]`.
93+
# In this example, since we construct `fg` with variables `hidden_variables` and `visible_variables`, which are both [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.vgroup.varray.NDVarArray.html#pgmax.vgroup.varray.NDVarArray)s, we can refer to the `ii`th hidden variable as `hidden_variables[ii]` and the `jj`th visible variable as `visible_variables[jj]`.
94+
#
95+
# An alternative way of creating the above factors is to add them iteratively without building the [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.fgroup.FactorGroup.html#pgmax.fgroup.fgroup.FactorGroup)s as below. This approach is not recommended as it can be much slower than using [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.fgroup.FactorGroup.html#pgmax.fgroup.fgroup.FactorGroup)s.
9496
#
95-
# An alternative way of creating the above factors is to add them iteratively without building the [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)s as below. This approach is not recommended as it is not computationally efficient.
9697
# ~~~python
9798
# from pgmax import factor
9899
# import itertools
@@ -176,7 +177,7 @@
176177
# ~~~python
177178
# bp = infer.BP(fg.bp_state, temperature=T)
178179
# ~~~
179-
# where the arguments of the `this_bp` are several useful functions to run LBP. In particular, `bp.init`, `bp.run_bp`, `bp.get_beliefs` are pure functions with no side-effects. This design choice means that we can easily apply JAX transformations like `jit`/`vmap`/`grad`, etc., to these functions, and additionally allows PGMax to seamlessly interact with other packages in the rapidly growing JAX ecosystem (see [here](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) and [here](https://github.com/n2cholas/awesome-jax)).
180+
# where the arguments of `bp` are several useful functions to run LBP. In particular, `bp.init`, `bp.run_bp`, `bp.get_beliefs` are pure functions with no side-effects. This design choice means that we can easily apply JAX transformations like `jit`/`vmap`/`grad`, etc., to these functions, and additionally allows PGMax to seamlessly interact with other packages in the rapidly growing JAX ecosystem (see [here](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) and [here](https://github.com/n2cholas/awesome-jax)).
180181
#
181182
# As an example of applying `jax.vmap` to `bp.init`/`bp.run_bp`/`bp.get_beliefs` to process a batch of samples/models in parallel, instead of drawing one sample at a time as above, we can draw a batch of samples in parallel as follows:
182183

pgmax/factor/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""A sub-package defining factors containing different types of factors."""
1+
"""A sub-package defining different types of factors."""
22

33
import collections
44
from typing import Callable, OrderedDict, Type

pgmax/factor/factor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""A module containing classes that specify the basic components of a factor."""
1+
"""A module containing the base classes for factors in a factor graph."""
22

33
from dataclasses import asdict, dataclass
44
from typing import List, Sequence, Tuple, Union

pgmax/factor/logical.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ class ORFactor(LogicalFactor):
226226
@dataclass(frozen=True, eq=False)
227227
class ANDFactor(LogicalFactor):
228228
"""An AND factor of the form (p1,...,pn, c)
229-
where p1,...,pn are the parents variables and c is the child variable.
229+
where p1,...,pn are the parents variables and c is the child variable.
230230
231231
An AND factor is defined as:
232232
F(p1, p2, ..., pn, c) = 0 <=> c = AND(p1, p2, ..., pn)

pgmax/fgraph/fgraph.py

+38-40
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
from __future__ import annotations
2-
3-
"""A module containing the core class to specify a Factor Graph."""
1+
"""A module containing the core class to build a factor graph."""
42

53
import collections
64
import copy
@@ -28,6 +26,43 @@
2826
from pgmax.utils import cached_property
2927

3028

29+
@dataclass(frozen=True, eq=False)
30+
class FactorGraphState:
31+
"""FactorGraphState.
32+
33+
Args:
34+
variable_groups: VarGroups in the FactorGraph.
35+
vars_to_starts: Maps variables to their starting indices in the flat evidence array.
36+
flat_evidence[vars_to_starts[variable]: vars_to_starts[variable] + variable.num_var_states]
37+
contains evidence to the variable.
38+
num_var_states: Total number of variable states.
39+
total_factor_num_states: Size of the flat ftov messages array.
40+
factor_type_to_msgs_range: Maps factors types to their start and end indices in the flat ftov messages.
41+
factor_type_to_potentials_range: Maps factor types to their start and end indices in the flat log potentials.
42+
factor_group_to_potentials_starts: Maps factor groups to their starting indices in the flat log potentials.
43+
log_potentials: Flat log potentials array concatenated for each factor type.
44+
wiring: Wiring derived for each factor type.
45+
"""
46+
47+
variable_groups: Sequence[vgroup.VarGroup]
48+
vars_to_starts: Mapping[Tuple[int, int], int]
49+
num_var_states: int
50+
total_factor_num_states: int
51+
factor_type_to_msgs_range: OrderedDict[type, Tuple[int, int]]
52+
factor_type_to_potentials_range: OrderedDict[type, Tuple[int, int]]
53+
factor_group_to_potentials_starts: OrderedDict[fgroup.FactorGroup, int]
54+
log_potentials: OrderedDict[type, Union[None, np.ndarray]]
55+
wiring: OrderedDict[type, factor.Wiring]
56+
57+
def __post_init__(self):
58+
for field in self.__dataclass_fields__:
59+
if isinstance(getattr(self, field), np.ndarray):
60+
getattr(self, field).flags.writeable = False
61+
62+
if isinstance(getattr(self, field), Mapping):
63+
object.__setattr__(self, field, MappingProxyType(getattr(self, field)))
64+
65+
3166
@dataclass
3267
class FactorGraph:
3368
"""Class for representing a factor graph.
@@ -294,40 +329,3 @@ def bp_state(self) -> Any:
294329
ftov_msgs=bp_state.FToVMessages(fg_state=self.fg_state),
295330
evidence=bp_state.Evidence(fg_state=self.fg_state),
296331
)
297-
298-
299-
@dataclass(frozen=True, eq=False)
300-
class FactorGraphState:
301-
"""FactorGraphState.
302-
303-
Args:
304-
variable_groups: VarGroups in the FactorGraph.
305-
vars_to_starts: Maps variables to their starting indices in the flat evidence array.
306-
flat_evidence[vars_to_starts[variable]: vars_to_starts[variable] + variable.num_var_states]
307-
contains evidence to the variable.
308-
num_var_states: Total number of variable states.
309-
total_factor_num_states: Size of the flat ftov messages array.
310-
factor_type_to_msgs_range: Maps factors types to their start and end indices in the flat ftov messages.
311-
factor_type_to_potentials_range: Maps factor types to their start and end indices in the flat log potentials.
312-
factor_group_to_potentials_starts: Maps factor groups to their starting indices in the flat log potentials.
313-
log_potentials: Flat log potentials array concatenated for each factor type.
314-
wiring: Wiring derived for each factor type.
315-
"""
316-
317-
variable_groups: Sequence[vgroup.VarGroup]
318-
vars_to_starts: Mapping[Tuple[int, int], int]
319-
num_var_states: int
320-
total_factor_num_states: int
321-
factor_type_to_msgs_range: OrderedDict[type, Tuple[int, int]]
322-
factor_type_to_potentials_range: OrderedDict[type, Tuple[int, int]]
323-
factor_group_to_potentials_starts: OrderedDict[fgroup.FactorGroup, int]
324-
log_potentials: OrderedDict[type, None | np.ndarray]
325-
wiring: OrderedDict[type, factor.Wiring]
326-
327-
def __post_init__(self):
328-
for field in self.__dataclass_fields__:
329-
if isinstance(getattr(self, field), np.ndarray):
330-
getattr(self, field).flags.writeable = False
331-
332-
if isinstance(getattr(self, field), Mapping):
333-
object.__setattr__(self, field, MappingProxyType(getattr(self, field)))

pgmax/fgroup/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""A sub-package defining factor groups and containing different types of factor groups."""
1+
"""A sub-package defining different types of groups of factors."""
22

33
from .enum import EnumFactorGroup, PairwiseFactorGroup
44
from .fgroup import FactorGroup, SingleFactorGroup

pgmax/fgroup/fgroup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""A module containing the base classes for factor groups in a Factor Graph."""
1+
"""A module containing the base classes for factor groups in a factor graph."""
22

33
import inspect
44
from dataclasses import dataclass, field

pgmax/infer/bp_state.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"Defines container classes for belief propagation states, and for the relevant flat arrays used in belief propagation."
1+
"A module defining container classes for belief propagation states."
22

33
import functools
44
from dataclasses import asdict, dataclass

pgmax/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""A module containing helper functions useful while constructing Factor Graphs."""
1+
"""A module containing helper functions."""
22

33
import functools
44
from typing import Callable

pgmax/vgroup/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""A sub-package defining variable groups and containing different types of variable groups."""
1+
"""A sub-package defining different types of groups of variables."""
22

33
from .varray import NDVarArray
44
from .vdict import VarDict

0 commit comments

Comments
 (0)