14
14
15
15
@dataclass
16
16
class FactorGraph :
17
- """Base class to represent a factor graph.
18
-
19
- Concrete factor graphs inherits from this class, and specifies get_evidence to generate
20
- the evidence array, and optionally init_msgs (default to initializing all messages to 0)
17
+ """Class for representing a factor graph
21
18
22
19
Args:
23
20
variable_groups: A container containing multiple VariableGroups, or a CompositeVariableGroup.
@@ -32,7 +29,7 @@ class FactorGraph:
32
29
33
30
Attributes:
34
31
_composite_variable_group: CompositeVariableGroup. contains all involved VariableGroups
35
- _factors: list. contains all involved factors
32
+ _factor_groups: List of added factor groups
36
33
num_var_states: int. represents the sum of all variable states of all variables in the
37
34
FactorGraph
38
35
_vars_to_starts: MappingProxyType[nodes.Variable, int]. maps every variable to an int
@@ -84,7 +81,7 @@ def __post_init__(self):
84
81
85
82
self ._vars_to_evidence : Dict [nodes .Variable , np .ndarray ] = {}
86
83
87
- self ._factors : List [nodes . EnumerationFactor ] = []
84
+ self ._factor_groups : List [groups . FactorGroup ] = []
88
85
89
86
def add_factors (
90
87
self ,
@@ -113,20 +110,24 @@ def add_factors(
113
110
"""
114
111
factor_factory = kwargs .pop ("factor_factory" , None )
115
112
if factor_factory is not None :
116
- factors = factor_factory (
113
+ factor_group = factor_factory (
117
114
self ._composite_variable_group , * args , ** kwargs
118
- ). factors
115
+ )
119
116
else :
120
117
if len (args ) > 0 :
121
118
new_args = list (args )
122
- new_args [0 ] = tuple (self ._composite_variable_group [args [0 ]])
123
- factors = [nodes .EnumerationFactor (* new_args , ** kwargs )]
119
+ new_args [0 ] = [args [0 ]]
120
+ factor_group = groups .EnumerationFactorGroup (
121
+ self ._composite_variable_group , * new_args , ** kwargs
122
+ )
124
123
else :
125
124
keys = kwargs .pop ("keys" )
126
- kwargs ["variables" ] = self ._composite_variable_group [keys ]
127
- factors = [nodes .EnumerationFactor (** kwargs )]
125
+ kwargs ["connected_var_keys" ] = [keys ]
126
+ factor_group = groups .EnumerationFactorGroup (
127
+ self ._composite_variable_group , ** kwargs
128
+ )
128
129
129
- self ._factors . extend ( factors )
130
+ self ._factor_groups . append ( factor_group )
130
131
131
132
@property
132
133
def wiring (self ) -> nodes .EnumerationWiring :
@@ -138,7 +139,8 @@ def wiring(self) -> nodes.EnumerationWiring:
138
139
compiled wiring from each individual factor
139
140
"""
140
141
wirings = [
141
- factor .compile_wiring (self ._vars_to_starts ) for factor in self ._factors
142
+ factor_group .compile_wiring (self ._vars_to_starts )
143
+ for factor_group in self ._factor_groups
142
144
]
143
145
wiring = fg_utils .concatenate_enumeration_wirings (wirings )
144
146
return wiring
@@ -154,7 +156,10 @@ def factor_configs_log_potentials(self) -> np.ndarray:
154
156
valid configuration
155
157
"""
156
158
return np .concatenate (
157
- [factor .factor_configs_log_potentials for factor in self ._factors ]
159
+ [
160
+ factor_group .factor_group_log_potentials
161
+ for factor_group in self ._factor_groups
162
+ ]
158
163
)
159
164
160
165
@property
@@ -182,6 +187,11 @@ def evidence(self) -> np.ndarray:
182
187
183
188
return evidence
184
189
190
+ @property
191
+ def factors (self ) -> Tuple [nodes .EnumerationFactor , ...]:
192
+ """List of individual factors in the factor graph"""
193
+ return sum ([factor_group .factors for factor_group in self ._factor_groups ], ())
194
+
185
195
def get_init_msgs (self , context : Any = None ):
186
196
"""Function to initialize messages.
187
197
@@ -201,7 +211,7 @@ def set_evidence(
201
211
self ,
202
212
key : Union [Tuple [Any , ...], Any ],
203
213
evidence : Union [Dict [Any , np .ndarray ], np .ndarray ],
204
- ):
214
+ ) -> None :
205
215
"""Function to update the evidence for variables in the FactorGraph.
206
216
207
217
Args:
0 commit comments