|
1 |
| -from __future__ import annotations |
2 |
| - |
3 |
| -"""A module containing the core class to specify a Factor Graph.""" |
| 1 | +"""A module containing the core class to build a factor graph.""" |
4 | 2 |
|
5 | 3 | import collections
|
6 | 4 | import copy
|
|
28 | 26 | from pgmax.utils import cached_property
|
29 | 27 |
|
30 | 28 |
|
| 29 | +@dataclass(frozen=True, eq=False) |
| 30 | +class FactorGraphState: |
| 31 | + """FactorGraphState. |
| 32 | +
|
| 33 | + Args: |
| 34 | + variable_groups: VarGroups in the FactorGraph. |
| 35 | + vars_to_starts: Maps variables to their starting indices in the flat evidence array. |
| 36 | + flat_evidence[vars_to_starts[variable]: vars_to_starts[variable] + variable.num_var_states] |
| 37 | + contains evidence to the variable. |
| 38 | + num_var_states: Total number of variable states. |
| 39 | + total_factor_num_states: Size of the flat ftov messages array. |
| 40 | + factor_type_to_msgs_range: Maps factors types to their start and end indices in the flat ftov messages. |
| 41 | + factor_type_to_potentials_range: Maps factor types to their start and end indices in the flat log potentials. |
| 42 | + factor_group_to_potentials_starts: Maps factor groups to their starting indices in the flat log potentials. |
| 43 | + log_potentials: Flat log potentials array concatenated for each factor type. |
| 44 | + wiring: Wiring derived for each factor type. |
| 45 | + """ |
| 46 | + |
| 47 | + variable_groups: Sequence[vgroup.VarGroup] |
| 48 | + vars_to_starts: Mapping[Tuple[int, int], int] |
| 49 | + num_var_states: int |
| 50 | + total_factor_num_states: int |
| 51 | + factor_type_to_msgs_range: OrderedDict[type, Tuple[int, int]] |
| 52 | + factor_type_to_potentials_range: OrderedDict[type, Tuple[int, int]] |
| 53 | + factor_group_to_potentials_starts: OrderedDict[fgroup.FactorGroup, int] |
| 54 | + log_potentials: OrderedDict[type, Union[None, np.ndarray]] |
| 55 | + wiring: OrderedDict[type, factor.Wiring] |
| 56 | + |
| 57 | + def __post_init__(self): |
| 58 | + for field in self.__dataclass_fields__: |
| 59 | + if isinstance(getattr(self, field), np.ndarray): |
| 60 | + getattr(self, field).flags.writeable = False |
| 61 | + |
| 62 | + if isinstance(getattr(self, field), Mapping): |
| 63 | + object.__setattr__(self, field, MappingProxyType(getattr(self, field))) |
| 64 | + |
| 65 | + |
31 | 66 | @dataclass
|
32 | 67 | class FactorGraph:
|
33 | 68 | """Class for representing a factor graph.
|
@@ -294,40 +329,3 @@ def bp_state(self) -> Any:
|
294 | 329 | ftov_msgs=bp_state.FToVMessages(fg_state=self.fg_state),
|
295 | 330 | evidence=bp_state.Evidence(fg_state=self.fg_state),
|
296 | 331 | )
|
297 |
| - |
298 |
| - |
299 |
| -@dataclass(frozen=True, eq=False) |
300 |
| -class FactorGraphState: |
301 |
| - """FactorGraphState. |
302 |
| -
|
303 |
| - Args: |
304 |
| - variable_groups: VarGroups in the FactorGraph. |
305 |
| - vars_to_starts: Maps variables to their starting indices in the flat evidence array. |
306 |
| - flat_evidence[vars_to_starts[variable]: vars_to_starts[variable] + variable.num_var_states] |
307 |
| - contains evidence to the variable. |
308 |
| - num_var_states: Total number of variable states. |
309 |
| - total_factor_num_states: Size of the flat ftov messages array. |
310 |
| - factor_type_to_msgs_range: Maps factors types to their start and end indices in the flat ftov messages. |
311 |
| - factor_type_to_potentials_range: Maps factor types to their start and end indices in the flat log potentials. |
312 |
| - factor_group_to_potentials_starts: Maps factor groups to their starting indices in the flat log potentials. |
313 |
| - log_potentials: Flat log potentials array concatenated for each factor type. |
314 |
| - wiring: Wiring derived for each factor type. |
315 |
| - """ |
316 |
| - |
317 |
| - variable_groups: Sequence[vgroup.VarGroup] |
318 |
| - vars_to_starts: Mapping[Tuple[int, int], int] |
319 |
| - num_var_states: int |
320 |
| - total_factor_num_states: int |
321 |
| - factor_type_to_msgs_range: OrderedDict[type, Tuple[int, int]] |
322 |
| - factor_type_to_potentials_range: OrderedDict[type, Tuple[int, int]] |
323 |
| - factor_group_to_potentials_starts: OrderedDict[fgroup.FactorGroup, int] |
324 |
| - log_potentials: OrderedDict[type, None | np.ndarray] |
325 |
| - wiring: OrderedDict[type, factor.Wiring] |
326 |
| - |
327 |
| - def __post_init__(self): |
328 |
| - for field in self.__dataclass_fields__: |
329 |
| - if isinstance(getattr(self, field), np.ndarray): |
330 |
| - getattr(self, field).flags.writeable = False |
331 |
| - |
332 |
| - if isinstance(getattr(self, field), Mapping): |
333 |
| - object.__setattr__(self, field, MappingProxyType(getattr(self, field))) |
0 commit comments