Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Commit f8555a4

Browse files
Adds unit tests (#41)
* creates unit tests for bp_utils * creates E2E integration tests for pgmax * adds codecov to CI * updates ci and poetry deps * ci now uses pytest-cov to generate reports * removes coverage requirement from poetry dev reqs * updates poetry to depend on strictly python 3.7 (others cause problems with pre-commit, etc.) * minor update to gitignore * updates tests to 100% coverage * updates isort config to make isort play nicely with black Co-authored-by: Guangyao Zhou <[email protected]>
1 parent a6be2a6 commit f8555a4

18 files changed

+657
-314
lines changed

.coveragerc

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[report]
2+
exclude_lines =
3+
# Have to re-enable the standard pragma
4+
pragma: no cover
5+
6+
# Don't complain if tests don't hit defensive assertion code:
7+
raise NotImplementedError

.github/workflows/ci.yaml

+10-4
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,22 @@ jobs:
6363
- name: Install library
6464
run: poetry install --no-interaction
6565
#----------------------------------------------
66-
# run test suite
66+
# run test suite with coverage
6767
#----------------------------------------------
68-
- name: Test with pytest
68+
- name: Test with coverage
6969
run: |
70-
poetry run pytest
70+
poetry run pytest --cov=pgmax --cov-report=xml
71+
#----------------------------------------------
72+
# upload coverage report to codecov
73+
#----------------------------------------------
74+
- name: Upload Coverage to Codecov
75+
uses: codecov/codecov-action@v2
76+
with:
77+
verbose: true # optional (default = false)
7178
#----------------------------------------------
7279
# test docs build only on PR
7380
#----------------------------------------------
7481
- name: Test docs build
75-
if: ${{ github.event_name == 'pull_request' }}
7682
run: |
7783
cd docs
7884
poetry run make html

.gitignore

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
.vscode/
2-
31
# Byte-compiled / optimized / DLL files
42
__pycache__/
53
*.py[cod]
@@ -131,3 +129,7 @@ dmypy.json
131129

132130
# Pyre type checker
133131
.pyre/
132+
.ruby-version
133+
134+
# VSCode settings
135+
.vscode/

.isort.cfg

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[settings]
2+
profile=black

codecov.yaml

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
ignore:
2+
- "docs/**/*"
3+
- "tests/**/*"
4+
- "examples/**/*"

examples/heretic_example.py

+16-17
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,31 @@
88
# format_version: '1.3'
99
# jupytext_version: 1.11.4
1010
# kernelspec:
11-
# display_name: 'Python 3.8.5 64-bit (''pgmax-JcKb81GE-py3.8'': poetry)'
12-
# name: python3
11+
# display_name: 'Python 3.7.11 64-bit (''pgmax-zIh0MZVc-py3.7'': venv)'
12+
# name: python371164bitpgmaxzih0mzvcpy37venve540bb1b5cdf4292a3f5a12c4904cc40
1313
# ---
1414

15+
from timeit import default_timer as timer
16+
from typing import Any, List, Tuple
17+
18+
import jax
19+
import jax.numpy as jnp
20+
1521
# %%
1622
# %matplotlib inline
17-
# fmt: off
18-
1923
# Standard Package Imports
20-
import matplotlib.pyplot as plt # isort:skip
21-
import numpy as np # isort:skip
22-
import jax # isort:skip
23-
import jax.numpy as jnp # isort:skip
24-
from typing import Any, Tuple, List # isort:skip
25-
from timeit import default_timer as timer # isort:skip
24+
import matplotlib.pyplot as plt
25+
import numpy as np
2626

27-
# Custom Imports
28-
import pgmax.fg.groups as groups # isort:skip
29-
import pgmax.fg.graph as graph # isort:skip
27+
import pgmax.fg.graph as graph
3028

31-
# fmt: on
29+
# Custom Imports
30+
import pgmax.fg.groups as groups
3231

3332
# %% [markdown]
3433
# # Setup Variables
3534

36-
# %%
35+
# %% tags=[]
3736
# Define some global constants
3837
im_size = (30, 30)
3938
prng_key = jax.random.PRNGKey(42)
@@ -100,7 +99,7 @@
10099
# %% [markdown]
101100
# # Add all Factors to graph via constructing FactorGroups
102101

103-
# %%
102+
# %% tags=[]
104103
def binary_connected_variables(
105104
num_hidden_rows, num_hidden_cols, kernel_row, kernel_col
106105
):
@@ -182,7 +181,7 @@ def custom_flatten_ordering(Mdown, Mup):
182181
# %% [markdown]
183182
# # Run Belief Propagation and Retrieve MAP Estimate
184183

185-
# %%
184+
# %% tags=[]
186185
# Run BP
187186
bp_start_time = timer()
188187
final_msgs = fg.run_bp(

examples/sanity_check_example.py

+10-13
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,21 @@
1414

1515
# %%
1616
# %matplotlib inline
17-
# fmt: off
1817
import os
18+
from timeit import default_timer as timer
19+
from typing import Any, Dict, List, Tuple
20+
21+
# Standard Package Imports
22+
import matplotlib.pyplot as plt
23+
import numpy as np
24+
from numpy.random import default_rng
25+
from scipy import sparse
26+
from scipy.ndimage import gaussian_filter
1927

2028
import pgmax.fg.graph as graph
2129

2230
# Custom Imports
23-
import pgmax.fg.groups as groups # isort:skip
24-
25-
# Standard Package Imports
26-
import matplotlib.pyplot as plt # isort:skip
27-
import numpy as np # isort:skip
28-
from numpy.random import default_rng # isort:skip
29-
from scipy import sparse # isort:skip
30-
from scipy.ndimage import gaussian_filter # isort:skip
31-
from typing import Any, Dict, Tuple, List # isort:skip
32-
from timeit import default_timer as timer # isort:skip
33-
34-
# fmt: on
31+
import pgmax.fg.groups as groups
3532

3633
# %% [markdown]
3734
# ## Setting up Image

pgmax/bp/bp_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def segment_max_opt(
1818
1919
Args:
2020
data: Array of shape (a_len,) where a_len is an arbitrary integer.
21-
segments_lengths: Array of shape (num_segments,), where num_segments <= a_len.
22-
segments_lengths.sum() should yield a_len.
21+
segments_lengths: Array of shape (num_segments,), where 0 < num_segments <= a_len.
22+
segments_lengths.sum() should yield a_len, and all elements must be > 0.
2323
max_segment_length: The maximum value in segments_lengths.
2424
2525
Returns:

pgmax/fg/graph.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,12 @@ def add_factors(
104104
**kwargs: optional mapping of keyword arguments. If specified, and if there
105105
is no "factor_factory" key specified as part of this mapping, then these
106106
args are taken to specify the arguments to be used to instantiate an
107-
EnumerationFactor. If there is a "factor_factory" key, then these args
108-
are taken to specify the arguments to be used to construct the class
109-
specified by the "factor_factory" argument. Note that either *args or
110-
**kwargs must be specified.
107+
EnumerationFactor (specify a kwarg with the key 'keys' to indicate the
108+
indices of variables ot be indexed to create the EnumerationFactor).
109+
If there is a "factor_factory" key, then these args are taken to specify
110+
the arguments to be used to construct the class specified by the
111+
"factor_factory" argument. Note that either *args or **kwargs must be
112+
specified.
111113
"""
112114
factor_factory = kwargs.pop("factor_factory", None)
113115
if factor_factory is not None:
@@ -168,7 +170,7 @@ def evidence(self) -> np.ndarray:
168170
if self.evidence_default_mode == "zeros":
169171
evidence = np.zeros(self.num_var_states)
170172
elif self.evidence_default_mode == "random":
171-
evidence = np.random.gumbel(self.num_var_states)
173+
evidence = np.random.gumbel(size=self.num_var_states)
172174
else:
173175
raise NotImplementedError(
174176
f"evidence_default_mode {self.evidence_default_mode} is not yet implemented"

pgmax/fg/groups.py

+20-30
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import typing
33
from dataclasses import dataclass, field
44
from types import MappingProxyType
5-
from typing import Any, Dict, Hashable, List, Mapping, Sequence, Tuple, Union
5+
from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
66

77
import numpy as np
88

@@ -12,7 +12,7 @@
1212

1313
@dataclass(frozen=True, eq=False)
1414
class VariableGroup:
15-
"""Base class to represent a group of variables.
15+
"""Class to represent a group of variables.
1616
1717
All variables in the group are assumed to have the same size. Additionally, the
1818
variables are indexed by a "key", and can be retrieved by direct indexing (even indexing
@@ -32,11 +32,11 @@ def __post_init__(self) -> None:
3232

3333
@typing.overload
3434
def __getitem__(self, key: Hashable) -> nodes.Variable:
35-
pass
35+
"""This function is a typing overload and is overwritten by the implemented __getitem__"""
3636

3737
@typing.overload
3838
def __getitem__(self, key: List) -> List[nodes.Variable]:
39-
pass
39+
"""This function is a typing overload and is overwritten by the implemented __getitem__"""
4040

4141
def __getitem__(self, key):
4242
"""Given a key, retrieve the associated Variable.
@@ -133,24 +133,17 @@ class CompositeVariableGroup(VariableGroup):
133133
]
134134

135135
def __post_init__(self):
136-
if (not isinstance(self.variable_group_container, Mapping)) and (
137-
not isinstance(self.variable_group_container, Sequence)
138-
):
139-
raise ValueError(
140-
f"variable_group_container needs to be a mapping or a sequence. Got {type(self.variable_group_container)}"
141-
)
142-
143136
object.__setattr__(
144137
self, "_keys_to_vars", MappingProxyType(self._set_keys_to_vars())
145138
)
146139

147140
@typing.overload
148141
def __getitem__(self, key: Hashable) -> nodes.Variable:
149-
pass
142+
"""This function is a typing overload and is overwritten by the implemented __getitem__"""
150143

151144
@typing.overload
152145
def __getitem__(self, key: List) -> List[nodes.Variable]:
153-
pass
146+
"""This function is a typing overload and is overwritten by the implemented __getitem__"""
154147

155148
def __getitem__(self, key):
156149
"""Given a key, retrieve the associated Variable from the associated VariableGroup.
@@ -213,7 +206,7 @@ def get_vars_to_evidence(
213206
214207
Args:
215208
evidence: A mapping or a sequence of evidences.
216-
The type of evidence should match that of self.variable_group_container
209+
The type of evidence should match that of self.variable_group_container.
217210
218211
Returns:
219212
a dictionary mapping all possible variables to the corresponding evidence
@@ -344,7 +337,7 @@ def get_vars_to_evidence(
344337

345338
if evidence[key].shape != (self.variable_size,):
346339
raise ValueError(
347-
f"Variable {key} expect an evidence array of shape "
340+
f"Variable {key} expects an evidence array of shape "
348341
f"({(self.variable_size,)})."
349342
f"Got {evidence[key].shape}."
350343
)
@@ -356,14 +349,14 @@ def get_vars_to_evidence(
356349

357350
@dataclass(frozen=True, eq=False)
358351
class FactorGroup:
359-
"""Base class to represent a group of factors.
352+
"""Class to represent a group of factors.
360353
361354
Args:
362355
variable_group: either a VariableGroup or - if the elements of more than one VariableGroup
363356
are connected to this FactorGroup - then a CompositeVariableGroup. This holds
364357
all the variables that are connected to this FactorGroup
365-
connected_var_keys: A list of tuples of tuples, where each innermost tuple contains a
366-
key variable_group. Each list within the outer list is taken to contain the keys of variables
358+
connected_var_keys: A list of list of tuples, where each innermost tuple contains a
359+
key into variable_group. Each list within the outer list is taken to contain the keys of variables
367360
neighboring a particular factor to be added.
368361
369362
Raises:
@@ -385,7 +378,7 @@ def factors(self) -> Tuple[nodes.EnumerationFactor, ...]:
385378

386379
@dataclass(frozen=True, eq=False)
387380
class EnumerationFactorGroup(FactorGroup):
388-
"""Base class to represent a group of EnumerationFactors.
381+
"""Class to represent a group of EnumerationFactors.
389382
390383
All factors in the group are assumed to have the same set of valid configurations and
391384
the same potential function. Note that the log potential function is assumed to be
@@ -398,27 +391,24 @@ class EnumerationFactorGroup(FactorGroup):
398391
Attributes:
399392
factors: a tuple of all the factors belonging to this group. These are constructed
400393
internally by invoking the _get_connected_var_keys_for_factors method.
401-
factor_configs_log_potentials: Can be specified by an inheriting class, or just left
402-
unspecified (equivalent to specifying None). If specified, must have (num_val_configs,).
403-
and contain the log of the potential value for every possible configuration.
404-
If none, it is assumed the log potential is uniform 0 and such an array is automatically
405-
initialized.
406-
394+
factor_configs_log_potentials: Optional ndarray of shape (num_val_configs,).
395+
if specified. Must contain the log of the potential value for every possible
396+
configuration. If left unspecified, it is assumed the log potential is uniform
397+
0 and such an array is automatically initialized.
407398
"""
408399

409400
factor_configs: np.ndarray
401+
factor_configs_log_potentials: Optional[np.ndarray] = None
410402

411403
@cached_property
412404
def factors(self) -> Tuple[nodes.EnumerationFactor, ...]:
413405
"""Returns a tuple of all the factors contained within this FactorGroup."""
414-
if getattr(self, "factor_configs_log_potentials", None) is None:
406+
if self.factor_configs_log_potentials is None:
415407
factor_configs_log_potentials = np.zeros(
416408
self.factor_configs.shape[0], dtype=float
417409
)
418410
else:
419-
factor_configs_log_potentials = getattr(
420-
self, "factor_configs_log_potentials"
421-
)
411+
factor_configs_log_potentials = self.factor_configs_log_potentials
422412

423413
return tuple(
424414
[
@@ -434,7 +424,7 @@ def factors(self) -> Tuple[nodes.EnumerationFactor, ...]:
434424

435425
@dataclass(frozen=True, eq=False)
436426
class PairwiseFactorGroup(FactorGroup):
437-
"""Base class to represent a group of EnumerationFactors where each factor connects to
427+
"""Class to represent a group of EnumerationFactors where each factor connects to
438428
two different variables.
439429
440430
All factors in the group are assumed to be such that all possible configuration of the two

0 commit comments

Comments
 (0)