Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Commit 975750f

Browse files
authored
Make FactorGraph aware of factor groups (#59)
* Make factor graph aware of factor groups * Docstrings * Fix tests * Address comments
1 parent f8555a4 commit 975750f

File tree

4 files changed

+63
-23
lines changed

4 files changed

+63
-23
lines changed

.isort.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
[settings]
2-
profile=black
2+
profile=black

pgmax/fg/graph.py

+26-16
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@
1414

1515
@dataclass
1616
class FactorGraph:
17-
"""Base class to represent a factor graph.
18-
19-
Concrete factor graphs inherits from this class, and specifies get_evidence to generate
20-
the evidence array, and optionally init_msgs (default to initializing all messages to 0)
17+
"""Class for representing a factor graph
2118
2219
Args:
2320
variable_groups: A container containing multiple VariableGroups, or a CompositeVariableGroup.
@@ -32,7 +29,7 @@ class FactorGraph:
3229
3330
Attributes:
3431
_composite_variable_group: CompositeVariableGroup. contains all involved VariableGroups
35-
_factors: list. contains all involved factors
32+
_factor_groups: List of added factor groups
3633
num_var_states: int. represents the sum of all variable states of all variables in the
3734
FactorGraph
3835
_vars_to_starts: MappingProxyType[nodes.Variable, int]. maps every variable to an int
@@ -84,7 +81,7 @@ def __post_init__(self):
8481

8582
self._vars_to_evidence: Dict[nodes.Variable, np.ndarray] = {}
8683

87-
self._factors: List[nodes.EnumerationFactor] = []
84+
self._factor_groups: List[groups.FactorGroup] = []
8885

8986
def add_factors(
9087
self,
@@ -113,20 +110,24 @@ def add_factors(
113110
"""
114111
factor_factory = kwargs.pop("factor_factory", None)
115112
if factor_factory is not None:
116-
factors = factor_factory(
113+
factor_group = factor_factory(
117114
self._composite_variable_group, *args, **kwargs
118-
).factors
115+
)
119116
else:
120117
if len(args) > 0:
121118
new_args = list(args)
122-
new_args[0] = tuple(self._composite_variable_group[args[0]])
123-
factors = [nodes.EnumerationFactor(*new_args, **kwargs)]
119+
new_args[0] = [args[0]]
120+
factor_group = groups.EnumerationFactorGroup(
121+
self._composite_variable_group, *new_args, **kwargs
122+
)
124123
else:
125124
keys = kwargs.pop("keys")
126-
kwargs["variables"] = self._composite_variable_group[keys]
127-
factors = [nodes.EnumerationFactor(**kwargs)]
125+
kwargs["connected_var_keys"] = [keys]
126+
factor_group = groups.EnumerationFactorGroup(
127+
self._composite_variable_group, **kwargs
128+
)
128129

129-
self._factors.extend(factors)
130+
self._factor_groups.append(factor_group)
130131

131132
@property
132133
def wiring(self) -> nodes.EnumerationWiring:
@@ -138,7 +139,8 @@ def wiring(self) -> nodes.EnumerationWiring:
138139
compiled wiring from each individual factor
139140
"""
140141
wirings = [
141-
factor.compile_wiring(self._vars_to_starts) for factor in self._factors
142+
factor_group.compile_wiring(self._vars_to_starts)
143+
for factor_group in self._factor_groups
142144
]
143145
wiring = fg_utils.concatenate_enumeration_wirings(wirings)
144146
return wiring
@@ -154,7 +156,10 @@ def factor_configs_log_potentials(self) -> np.ndarray:
154156
valid configuration
155157
"""
156158
return np.concatenate(
157-
[factor.factor_configs_log_potentials for factor in self._factors]
159+
[
160+
factor_group.factor_group_log_potentials
161+
for factor_group in self._factor_groups
162+
]
158163
)
159164

160165
@property
@@ -182,6 +187,11 @@ def evidence(self) -> np.ndarray:
182187

183188
return evidence
184189

190+
@property
191+
def factors(self) -> Tuple[nodes.EnumerationFactor, ...]:
192+
"""List of individual factors in the factor graph"""
193+
return sum([factor_group.factors for factor_group in self._factor_groups], ())
194+
185195
def get_init_msgs(self, context: Any = None):
186196
"""Function to initialize messages.
187197
@@ -201,7 +211,7 @@ def set_evidence(
201211
self,
202212
key: Union[Tuple[Any, ...], Any],
203213
evidence: Union[Dict[Any, np.ndarray], np.ndarray],
204-
):
214+
) -> None:
205215
"""Function to update the evidence for variables in the FactorGraph.
206216
207217
Args:

pgmax/fg/groups.py

+34-4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88

99
import pgmax.fg.nodes as nodes
10+
from pgmax.fg import fg_utils
1011
from pgmax.utils import cached_property
1112

1213

@@ -375,6 +376,35 @@ def __post_init__(self) -> None:
375376
def factors(self) -> Tuple[nodes.EnumerationFactor, ...]:
376377
raise NotImplementedError("Needs to be overriden by subclass")
377378

379+
def compile_wiring(
380+
self, vars_to_starts: Mapping[nodes.Variable, int]
381+
) -> nodes.EnumerationWiring:
382+
"""Function to compile wiring for the factor group.
383+
384+
Args:
385+
vars_to_starts: A dictionary that maps variables to their global starting indices
386+
For an n-state variable, a global start index of m means the global indices
387+
of its n variable states are m, m + 1, ..., m + n - 1
388+
389+
Returns:
390+
compiled wiring for the factor group
391+
"""
392+
wirings = [factor.compile_wiring(vars_to_starts) for factor in self.factors]
393+
wiring = fg_utils.concatenate_enumeration_wirings(wirings)
394+
return wiring
395+
396+
@cached_property
397+
def factor_group_log_potentials(self) -> np.ndarray:
398+
"""Function to compile potential array for the factor group
399+
400+
Returns:
401+
a jnp array representing the log of the potential function for
402+
the factor group
403+
"""
404+
return np.concatenate(
405+
[factor.factor_configs_log_potentials for factor in self.factors]
406+
)
407+
378408

379409
@dataclass(frozen=True, eq=False)
380410
class EnumerationFactorGroup(FactorGroup):
@@ -387,14 +417,14 @@ class EnumerationFactorGroup(FactorGroup):
387417
Args:
388418
factor_configs: Array of shape (num_val_configs, num_variables)
389419
An array containing explicit enumeration of all valid configurations
420+
factor_configs_log_potentials: Optional array of shape (num_val_configs,).
421+
If specified, it contains the log of the potential value for every possible configuration.
422+
If none, it is assumed the log potential is uniform 0 and such an array is automatically
423+
initialized.
390424
391425
Attributes:
392426
factors: a tuple of all the factors belonging to this group. These are constructed
393427
internally by invoking the _get_connected_var_keys_for_factors method.
394-
factor_configs_log_potentials: Optional ndarray of shape (num_val_configs,).
395-
if specified. Must contain the log of the potential value for every possible
396-
configuration. If left unspecified, it is assumed the log potential is uniform
397-
0 and such an array is automatically initialized.
398428
"""
399429

400430
factor_configs: np.ndarray

tests/test_pgmax.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def create_valid_suppression_config_arr(suppression_diameter):
306306
else:
307307
fg.add_factors(
308308
keys=curr_keys,
309-
configs=valid_configs_non_supp,
309+
factor_configs=valid_configs_non_supp,
310310
factor_configs_log_potentials=np.zeros(
311311
valid_configs_non_supp.shape[0], dtype=float
312312
),
@@ -417,4 +417,4 @@ def binary_connected_variables(
417417

418418
assert isinstance(fg.evidence, np.ndarray)
419419

420-
assert len(fg._factors) == 7056
420+
assert len(fg.factors) == 7056

0 commit comments

Comments
 (0)