2
2
import typing
3
3
from dataclasses import dataclass , field
4
4
from types import MappingProxyType
5
- from typing import Any , Dict , Hashable , List , Mapping , Sequence , Tuple , Union
5
+ from typing import Any , Dict , Hashable , List , Mapping , Optional , Sequence , Tuple , Union
6
6
7
7
import numpy as np
8
8
12
12
13
13
@dataclass (frozen = True , eq = False )
14
14
class VariableGroup :
15
- """Base class to represent a group of variables.
15
+ """Class to represent a group of variables.
16
16
17
17
All variables in the group are assumed to have the same size. Additionally, the
18
18
variables are indexed by a "key", and can be retrieved by direct indexing (even indexing
@@ -32,11 +32,11 @@ def __post_init__(self) -> None:
32
32
33
33
@typing .overload
34
34
def __getitem__ (self , key : Hashable ) -> nodes .Variable :
35
- pass
35
+ """This function is a typing overload and is overwritten by the implemented __getitem__"""
36
36
37
37
@typing .overload
38
38
def __getitem__ (self , key : List ) -> List [nodes .Variable ]:
39
- pass
39
+ """This function is a typing overload and is overwritten by the implemented __getitem__"""
40
40
41
41
def __getitem__ (self , key ):
42
42
"""Given a key, retrieve the associated Variable.
@@ -133,24 +133,17 @@ class CompositeVariableGroup(VariableGroup):
133
133
]
134
134
135
135
def __post_init__ (self ):
136
- if (not isinstance (self .variable_group_container , Mapping )) and (
137
- not isinstance (self .variable_group_container , Sequence )
138
- ):
139
- raise ValueError (
140
- f"variable_group_container needs to be a mapping or a sequence. Got { type (self .variable_group_container )} "
141
- )
142
-
143
136
object .__setattr__ (
144
137
self , "_keys_to_vars" , MappingProxyType (self ._set_keys_to_vars ())
145
138
)
146
139
147
140
@typing .overload
148
141
def __getitem__ (self , key : Hashable ) -> nodes .Variable :
149
- pass
142
+ """This function is a typing overload and is overwritten by the implemented __getitem__"""
150
143
151
144
@typing .overload
152
145
def __getitem__ (self , key : List ) -> List [nodes .Variable ]:
153
- pass
146
+ """This function is a typing overload and is overwritten by the implemented __getitem__"""
154
147
155
148
def __getitem__ (self , key ):
156
149
"""Given a key, retrieve the associated Variable from the associated VariableGroup.
@@ -213,7 +206,7 @@ def get_vars_to_evidence(
213
206
214
207
Args:
215
208
evidence: A mapping or a sequence of evidences.
216
- The type of evidence should match that of self.variable_group_container
209
+ The type of evidence should match that of self.variable_group_container.
217
210
218
211
Returns:
219
212
a dictionary mapping all possible variables to the corresponding evidence
@@ -344,7 +337,7 @@ def get_vars_to_evidence(
344
337
345
338
if evidence [key ].shape != (self .variable_size ,):
346
339
raise ValueError (
347
- f"Variable { key } expect an evidence array of shape "
340
+ f"Variable { key } expects an evidence array of shape "
348
341
f"({ (self .variable_size ,)} )."
349
342
f"Got { evidence [key ].shape } ."
350
343
)
@@ -356,14 +349,14 @@ def get_vars_to_evidence(
356
349
357
350
@dataclass (frozen = True , eq = False )
358
351
class FactorGroup :
359
- """Base class to represent a group of factors.
352
+ """Class to represent a group of factors.
360
353
361
354
Args:
362
355
variable_group: either a VariableGroup or - if the elements of more than one VariableGroup
363
356
are connected to this FactorGroup - then a CompositeVariableGroup. This holds
364
357
all the variables that are connected to this FactorGroup
365
- connected_var_keys: A list of tuples of tuples, where each innermost tuple contains a
366
- key variable_group. Each list within the outer list is taken to contain the keys of variables
358
+ connected_var_keys: A list of list of tuples, where each innermost tuple contains a
359
+ key into variable_group. Each list within the outer list is taken to contain the keys of variables
367
360
neighboring a particular factor to be added.
368
361
369
362
Raises:
@@ -385,7 +378,7 @@ def factors(self) -> Tuple[nodes.EnumerationFactor, ...]:
385
378
386
379
@dataclass (frozen = True , eq = False )
387
380
class EnumerationFactorGroup (FactorGroup ):
388
- """Base class to represent a group of EnumerationFactors.
381
+ """Class to represent a group of EnumerationFactors.
389
382
390
383
All factors in the group are assumed to have the same set of valid configurations and
391
384
the same potential function. Note that the log potential function is assumed to be
@@ -398,27 +391,24 @@ class EnumerationFactorGroup(FactorGroup):
398
391
Attributes:
399
392
factors: a tuple of all the factors belonging to this group. These are constructed
400
393
internally by invoking the _get_connected_var_keys_for_factors method.
401
- factor_configs_log_potentials: Can be specified by an inheriting class, or just left
402
- unspecified (equivalent to specifying None). If specified, must have (num_val_configs,).
403
- and contain the log of the potential value for every possible configuration.
404
- If none, it is assumed the log potential is uniform 0 and such an array is automatically
405
- initialized.
406
-
394
+ factor_configs_log_potentials: Optional ndarray of shape (num_val_configs,).
395
+ if specified. Must contain the log of the potential value for every possible
396
+ configuration. If left unspecified, it is assumed the log potential is uniform
397
+ 0 and such an array is automatically initialized.
407
398
"""
408
399
409
400
factor_configs : np .ndarray
401
+ factor_configs_log_potentials : Optional [np .ndarray ] = None
410
402
411
403
@cached_property
412
404
def factors (self ) -> Tuple [nodes .EnumerationFactor , ...]:
413
405
"""Returns a tuple of all the factors contained within this FactorGroup."""
414
- if getattr ( self , " factor_configs_log_potentials" , None ) is None :
406
+ if self . factor_configs_log_potentials is None :
415
407
factor_configs_log_potentials = np .zeros (
416
408
self .factor_configs .shape [0 ], dtype = float
417
409
)
418
410
else :
419
- factor_configs_log_potentials = getattr (
420
- self , "factor_configs_log_potentials"
421
- )
411
+ factor_configs_log_potentials = self .factor_configs_log_potentials
422
412
423
413
return tuple (
424
414
[
@@ -434,7 +424,7 @@ def factors(self) -> Tuple[nodes.EnumerationFactor, ...]:
434
424
435
425
@dataclass (frozen = True , eq = False )
436
426
class PairwiseFactorGroup (FactorGroup ):
437
- """Base class to represent a group of EnumerationFactors where each factor connects to
427
+ """Class to represent a group of EnumerationFactors where each factor connects to
438
428
two different variables.
439
429
440
430
All factors in the group are assumed to be such that all possible configuration of the two
0 commit comments