Skip to content

Commit 21fdc4b

Browse files
committed
Add prototype implementation of matrix decomposition and tests
1 parent 4619a59 commit 21fdc4b

File tree

2 files changed

+397
-2
lines changed

2 files changed

+397
-2
lines changed

cfpq_matrix/matrix_utils.py

+241-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
from typing import Any
1+
import random
2+
from collections import defaultdict
3+
from typing import Any, Tuple
24

35
import graphblas
4-
from graphblas.core.dtypes import DataType
6+
import numpy as np
7+
from graphblas.binary import plus
8+
from graphblas.core.dtypes import DataType, BOOL, INT32
59
from graphblas.core.matrix import Matrix
610
from graphblas.core.vector import Vector
711

@@ -22,3 +26,238 @@ def identity_matrix(one: Any, dtype: DataType, size: int) -> Matrix:
2226
size=size,
2327
dtype=dtype
2428
).diag()
29+
30+
def expand_matrix(matrix: Matrix, new_shape: Tuple[int, int]) -> Matrix:
31+
(rows, columns, values) = matrix.to_coo()
32+
return Matrix.from_coo(rows, columns, values, dtype=matrix.dtype, nrows=new_shape[0], ncols=new_shape[1])
33+
34+
def row_based_decompose(M: Matrix):
35+
"""
36+
Decomposes a sparse boolean matrix M into LEFT, RIGHT, and M' such that M = LEFT * RIGHT + M'.
37+
38+
Parameters:
39+
M (gb.Matrix): Input sparse boolean matrix.
40+
41+
Returns:
42+
LEFT (gb.Matrix): Left factor matrix.
43+
RIGHT (gb.Matrix): Right factor matrix.
44+
M_prime (gb.Matrix): Remainder matrix after decomposition.
45+
"""
46+
n_rows, n_cols = M.shape
47+
48+
I, J, V = M.to_coo()
49+
50+
rows = defaultdict(set)
51+
for i, j in zip(I, J):
52+
rows[i].add(j)
53+
54+
p = 2147483647
55+
num_hashes = 5 # TODO 2 or 3 is probably better for real world data
56+
hash_funcs = []
57+
for _ in range(num_hashes):
58+
a = random.randint(1, p - 1)
59+
b = random.randint(0, p - 1)
60+
hash_funcs.append((a, b))
61+
62+
minhashes = dict()
63+
64+
for i, S_i in rows.items():
65+
minhash_values = []
66+
if len(S_i) < 5:
67+
continue
68+
for a, b in hash_funcs:
69+
min_hash = min(((a * x + b) % p) for x in S_i)
70+
minhash_values.append(min_hash)
71+
minhashes[i] = tuple(minhash_values)
72+
73+
master_hashes = dict()
74+
for i, minhash_values in minhashes.items():
75+
master_hash = hash(minhash_values)
76+
master_hashes[i] = master_hash
77+
78+
buckets = defaultdict(list)
79+
for i, master_hash in master_hashes.items():
80+
buckets[master_hash].append(i)
81+
82+
buckets = {h: idxs for h, idxs in buckets.items() if len(idxs) >= 5}
83+
84+
LEFT_columns = []
85+
RIGHT_rows = []
86+
87+
for h, B in buckets.items():
88+
N = len(B)
89+
M_B: Matrix = M[B, :].new()
90+
A1 = M_B.dup(dtype=INT32).reduce_columnwise(plus).new()
91+
92+
threshold = int(0.95 * N)
93+
A2: Vector = A1.select('>=', threshold).new()
94+
95+
if A2.nvals == 0:
96+
continue
97+
98+
S_A2 = set(A2.to_coo()[0])
99+
100+
B_prime = [i for i in B if S_A2 <= rows[i]]
101+
102+
K = len(B_prime)
103+
if K == 0:
104+
continue
105+
106+
M_B_prime = M[B_prime, :].new()
107+
A3 = M_B_prime.dup(dtype=INT32).reduce_columnwise(plus)
108+
109+
threshold = int(0.95 * K)
110+
A4 = A3.select('>=', threshold).new()
111+
112+
if A4.nvals == 0:
113+
continue
114+
115+
S_A4 = set(A4.to_coo()[0])
116+
117+
B_double_prime = [i for i in B_prime if S_A4 <= rows[i]]
118+
119+
if len(B_double_prime) < 5:
120+
continue
121+
122+
RIGHT_rows.append(A4)
123+
124+
CORE = Vector(BOOL, size=n_rows)
125+
for i in B_double_prime:
126+
CORE[i] = True
127+
LEFT_columns.append(CORE)
128+
129+
num_buckets_remaining = len(LEFT_columns)
130+
if num_buckets_remaining == 0:
131+
return Matrix(M.dtype, M.nrows, 0), Matrix(M.dtype, 0, M.ncols)
132+
133+
LEFT = Matrix(bool, n_rows, num_buckets_remaining)
134+
for idx, CORE in enumerate(LEFT_columns):
135+
LEFT[:, idx] = CORE
136+
137+
RIGHT = Matrix(bool, num_buckets_remaining, n_cols)
138+
for idx, A4 in enumerate(RIGHT_rows):
139+
RIGHT[idx, :] = A4
140+
141+
return LEFT, RIGHT
142+
143+
def column_based_decompose(M: Matrix):
144+
LEFT_T, RIGHT_T = row_based_decompose(M.T.new())
145+
return RIGHT_T.T.new(), LEFT_T.T.new()
146+
147+
def decompose(M: Matrix):
148+
accumulated_LEFT = []
149+
accumulated_RIGHT = []
150+
iteration = 0
151+
152+
init_nvals = M.nvals
153+
if init_nvals == 0:
154+
return Matrix(M.dtype, M.nrows, 0), Matrix(M.dtype, 0, M.ncols)
155+
156+
while True:
157+
iteration += 1
158+
nvals_before = M.nvals
159+
160+
LEFT1, RIGHT1 = row_based_decompose(M)
161+
162+
if LEFT1.nvals != 0:
163+
M = M.dup(mask=~LEFT1.mxm(RIGHT1, op=graphblas.semiring.any_pair).new(dtype=BOOL).S)
164+
165+
LEFT2, RIGHT2 = column_based_decompose(M)
166+
167+
if LEFT2.nvals != 0:
168+
M = M.dup(mask=~LEFT2.mxm(RIGHT2, op=graphblas.semiring.any_pair).new(dtype=BOOL).S)
169+
170+
nvals_LEFT_RIGHT = LEFT1.nvals + RIGHT1.nvals + LEFT2.nvals + RIGHT2.nvals
171+
172+
nvals_after = M.nvals
173+
delta_M = nvals_before - nvals_after
174+
175+
reduction_ratio = delta_M / nvals_before if nvals_before > 0 else 0
176+
size_ratio = nvals_LEFT_RIGHT / delta_M if delta_M > 0 else float('inf')
177+
178+
accumulated_LEFT.extend([LEFT1, LEFT2])
179+
accumulated_RIGHT.extend([RIGHT1, RIGHT2])
180+
181+
if reduction_ratio < 0.05 or size_ratio > 0.3:
182+
break
183+
184+
if M.nvals == 0:
185+
break
186+
187+
if not accumulated_LEFT or not accumulated_RIGHT:
188+
return Matrix(BOOL, nrows=M.nrows, ncols=0), Matrix(BOOL, nrows=0, ncols=M.ncols)
189+
190+
LEFT = stack([accumulated_LEFT])
191+
RIGHT = stack([[RIGHT] for RIGHT in accumulated_RIGHT])
192+
193+
return LEFT, RIGHT
194+
195+
def stack(matrix_grid: list[list[Matrix]]) -> Matrix:
196+
"""
197+
Stack a 2D list of matrices into a single larger matrix.
198+
Vertically stacks matrices within each row of the list, and then horizontally stacks the results.
199+
200+
Parameters:
201+
matrix_grid (list[list[Matrix]]): A 2D list of matrices to stack.
202+
203+
Returns:
204+
Matrix: The stacked matrix.
205+
"""
206+
if not matrix_grid or not matrix_grid[0]:
207+
raise ValueError("The matrix grid cannot be empty.")
208+
209+
num_cols = len(matrix_grid[0])
210+
for row in matrix_grid:
211+
if len(row) != num_cols:
212+
raise ValueError("All rows in the matrix grid must have the same number of matrices.")
213+
214+
for row in matrix_grid:
215+
row_height = row[0].nrows
216+
for matrix in row:
217+
if matrix.nrows != row_height:
218+
raise ValueError("All matrices in the same row must have the same number of rows.")
219+
220+
for col in range(num_cols):
221+
col_width = matrix_grid[0][col].ncols
222+
for row in matrix_grid:
223+
if row[col].ncols != col_width:
224+
raise ValueError("All matrices in the same column must have the same number of columns.")
225+
226+
combined_rows = []
227+
combined_columns = []
228+
combined_values = []
229+
230+
current_row_offset = 0
231+
232+
for row in matrix_grid:
233+
current_col_offset = 0
234+
235+
for matrix in row:
236+
M_I, M_J, M_V = matrix.to_coo()
237+
238+
adjusted_rows = M_I + current_row_offset
239+
adjusted_columns = M_J + current_col_offset
240+
241+
combined_rows.append(adjusted_rows)
242+
combined_columns.append(adjusted_columns)
243+
combined_values.append(M_V)
244+
245+
current_col_offset += matrix.ncols
246+
247+
current_row_offset += row[0].nrows
248+
249+
final_rows = np.concatenate(combined_rows)
250+
final_columns = np.concatenate(combined_columns)
251+
final_values = np.concatenate(combined_values)
252+
253+
total_rows = current_row_offset
254+
total_columns = sum(matrix.ncols for matrix in matrix_grid[0])
255+
256+
return Matrix.from_coo(
257+
rows=final_rows,
258+
columns=final_columns,
259+
values=final_values,
260+
dtype=matrix_grid[0][0].dtype,
261+
nrows=total_rows,
262+
ncols=total_columns,
263+
)

0 commit comments

Comments
 (0)