|
1 | 1 | from abc import ABC, abstractmethod |
2 | 2 | from typing import Tuple |
3 | 3 |
|
4 | | -from graphblas.core.matrix import Matrix |
5 | | -from graphblas.core.dtypes import BOOL |
6 | | -import numpy as np |
7 | 4 | import graphblas |
| 5 | +import pytest |
| 6 | +from graphblas.core.dtypes import BOOL |
| 7 | +from graphblas.core.matrix import Matrix |
8 | 8 |
|
9 | 9 | from cfpq_decomposer.decomposer import Decomposer |
10 | | - |
11 | | -class TestAbstractDecomposer(ABC): |
| 10 | +from test.cfpq_decomposer.synthetic_data import ( |
| 11 | + similar_rows_matrix, |
| 12 | + multiple_patterns_matrix, |
| 13 | + double_threshold_matrix, |
| 14 | + similar_columns_matrix, |
| 15 | + random_matrix_with_patterns, MIN_COMPRESSION_FACTORS |
| 16 | +) |
| 17 | + |
class AbstractDecomposerTest(ABC):
    """Reusable test suite for ``Decomposer`` implementations.

    Concrete subclasses provide the decomposer under test by implementing
    ``create_decomposer``; the parametrized test then runs it against each
    synthetic matrix factory imported at the top of this module.
    """

    @abstractmethod
    def create_decomposer(self) -> Decomposer:
        """Return the decomposer instance the tests should exercise."""

    def decompose(self, matrix: Matrix) -> Tuple[Matrix, Matrix]:
        """Decompose ``matrix`` using a newly created decomposer.

        Returns whatever the decomposer's ``decompose`` call yields, which
        the tests unpack as a (left, right) pair of matrices.
        """
        decomposer = self.create_decomposer()
        return decomposer.decompose(matrix)

    @pytest.mark.CI
    @pytest.mark.parametrize(
        "matrix_fn,key",
        [
            pytest.param(factory, label, id=label)
            for factory, label in (
                (similar_rows_matrix, "similar_rows"),
                (multiple_patterns_matrix, "multiple_patterns"),
                (double_threshold_matrix, "double_threshold"),
                (similar_columns_matrix, "similar_columns"),
                (random_matrix_with_patterns, "random_patterns"),
            )
        ],
    )
    def test_decomposition(self, matrix_fn, key):
        """Check one synthetic matrix for correctness and compression.

        ``matrix_fn`` builds the input matrix; ``key`` selects the matching
        lower bound in ``MIN_COMPRESSION_FACTORS`` and labels the output.
        """
        source = matrix_fn()
        left_factor, right_factor = self.decompose(source)
        product = left_factor.mxm(
            right_factor, op=graphblas.semiring.any_pair
        ).new(dtype=BOOL)

        # Masking with the complement of the source's structure keeps only
        # product entries that the source does not have; there must be none.
        spurious = product.dup(mask=~source.S)
        assert spurious.nvals == 0

        # Source entries that the product fails to reproduce still have to be
        # stored, so they count against the compression factor.
        remainder = source.dup(mask=~product.S).nvals
        stored_entries = left_factor.nvals + right_factor.nvals + remainder
        compression_factor = source.nvals / stored_entries
        print(f"Compression factor for {key} is {compression_factor}")
        assert compression_factor > MIN_COMPRESSION_FACTORS[key]
0 commit comments