1
- from typing import Any
1
+ import random
2
+ from collections import defaultdict
3
+ from typing import Any , Tuple
2
4
3
5
import graphblas
4
- from graphblas .core .dtypes import DataType
6
+ import numpy as np
7
+ from graphblas .binary import plus
8
+ from graphblas .core .dtypes import DataType , BOOL , INT32
5
9
from graphblas .core .matrix import Matrix
6
10
from graphblas .core .vector import Vector
7
11
@@ -22,3 +26,238 @@ def identity_matrix(one: Any, dtype: DataType, size: int) -> Matrix:
22
26
size = size ,
23
27
dtype = dtype
24
28
).diag ()
29
+
30
def expand_matrix(matrix: Matrix, new_shape: Tuple[int, int]) -> Matrix:
    """
    Rebuild `matrix` with the dimensions given by `new_shape`.

    The stored entries are exported as coordinate triples and re-imported
    unchanged into a fresh matrix of the requested size.

    Parameters:
        matrix (Matrix): Matrix whose entries are copied.
        new_shape (Tuple[int, int]): Target ``(nrows, ncols)``.

    Returns:
        Matrix: New matrix of shape `new_shape` with the same dtype and entries.
    """
    row_indices, col_indices, entries = matrix.to_coo()
    return Matrix.from_coo(
        row_indices,
        col_indices,
        entries,
        dtype=matrix.dtype,
        nrows=new_shape[0],
        ncols=new_shape[1],
    )
33
+
34
def row_based_decompose(M: Matrix):
    """
    Find groups of rows of a sparse boolean matrix that share a common set of columns.

    Rows are bucketed by a MinHash signature of their column sets. For every
    sufficiently large bucket a common column set is derived in two refinement
    passes, and the rows that contain that whole set become one rank-1 block:
    column k of LEFT marks the rows, row k of RIGHT marks the shared columns.
    By construction every entry of the pattern of LEFT * RIGHT is present in M.

    The hash functions are drawn from `random`, so results vary between calls
    unless the caller seeds the module-level generator.

    Parameters:
        M (Matrix): Input sparse boolean matrix.

    Returns:
        LEFT (Matrix): n_rows x k left factor (k = number of blocks found).
        RIGHT (Matrix): k x n_cols right factor.

        When no block is found, the factors are empty matrices with k = 0 and
        the dtype of M. The remainder M' = M minus LEFT * RIGHT is NOT
        returned; callers compute it themselves.
    """
    min_row_size = 5      # rows with fewer entries are not worth factoring
    min_bucket_size = 5   # minimum number of rows forming a block
    density = 0.95        # fraction of rows in which a column must appear

    n_rows, n_cols = M.shape

    row_idx, col_idx, _ = M.to_coo()

    # Column set of every non-empty row. Indices are converted to plain Python
    # ints so the modular hash arithmetic below is exact and does not depend on
    # NumPy scalar promotion/overflow behaviour.
    rows = defaultdict(set)
    for i, j in zip(row_idx.tolist(), col_idx.tolist()):
        rows[i].add(j)

    # Universal hash family h(x) = (a*x + b) mod p with p = 2**31 - 1.
    p = 2147483647
    num_hashes = 5  # TODO 2 or 3 is probably better for real world data
    hash_funcs = [
        (random.randint(1, p - 1), random.randint(0, p - 1))
        for _ in range(num_hashes)
    ]

    # Bucket rows by their full MinHash signature; rows with similar column
    # sets are likely to land in the same bucket.
    buckets = defaultdict(list)
    for i, S_i in rows.items():
        if len(S_i) < min_row_size:
            continue
        signature = tuple(
            min((a * x + b) % p for x in S_i) for a, b in hash_funcs
        )
        buckets[signature].append(i)

    LEFT_columns = []
    RIGHT_rows = []

    for B in buckets.values():
        N = len(B)
        if N < min_bucket_size:
            continue

        # Per-column count of bucket rows that have an entry in that column.
        M_B: Matrix = M[B, :].new()
        A1 = M_B.dup(dtype=INT32).reduce_columnwise(plus).new()

        # First pass: columns present in (almost) all rows of the bucket.
        A2: Vector = A1.select('>=', int(density * N)).new()
        if A2.nvals == 0:
            continue
        S_A2 = set(A2.to_coo()[0].tolist())

        # Keep only the rows that contain every one of those columns.
        B_prime = [i for i in B if S_A2 <= rows[i]]
        K = len(B_prime)
        if K == 0:
            continue

        # Second pass: recompute the common columns on the refined row set.
        # `.new()` materialises the reduction before `select` is applied,
        # matching the first pass.
        M_B_prime = M[B_prime, :].new()
        A3 = M_B_prime.dup(dtype=INT32).reduce_columnwise(plus).new()

        A4 = A3.select('>=', int(density * K)).new()
        if A4.nvals == 0:
            continue
        S_A4 = set(A4.to_coo()[0].tolist())

        # Final row set: rows containing the whole shared column set, so the
        # resulting block is fully covered by entries of M.
        B_double_prime = [i for i in B_prime if S_A4 <= rows[i]]
        if len(B_double_prime) < min_bucket_size:
            continue

        RIGHT_rows.append(A4)

        CORE = Vector(BOOL, size=n_rows)
        for i in B_double_prime:
            CORE[i] = True
        LEFT_columns.append(CORE)

    num_buckets_remaining = len(LEFT_columns)
    if num_buckets_remaining == 0:
        return Matrix(M.dtype, M.nrows, 0), Matrix(M.dtype, 0, M.ncols)

    LEFT = Matrix(BOOL, n_rows, num_buckets_remaining)
    for idx, CORE in enumerate(LEFT_columns):
        LEFT[:, idx] = CORE

    # A4 holds INT32 counts; assigning into a BOOL matrix keeps only the pattern.
    RIGHT = Matrix(BOOL, num_buckets_remaining, n_cols)
    for idx, A4 in enumerate(RIGHT_rows):
        RIGHT[idx, :] = A4

    return LEFT, RIGHT
142
+
143
def column_based_decompose(M: Matrix):
    """
    Column-oriented counterpart of `row_based_decompose`.

    Runs the row-based decomposition on the transpose of M and transposes the
    factors back, swapping their roles.

    Parameters:
        M (Matrix): Input sparse boolean matrix.

    Returns:
        LEFT (Matrix): Transpose of the right factor found for M transposed.
        RIGHT (Matrix): Transpose of the left factor found for M transposed.
    """
    transposed = M.T.new()
    left_of_transpose, right_of_transpose = row_based_decompose(transposed)
    LEFT = right_of_transpose.T.new()
    RIGHT = left_of_transpose.T.new()
    return LEFT, RIGHT
146
+
147
def _mask_out_product(M: Matrix, LEFT: Matrix, RIGHT: Matrix) -> Matrix:
    """Return a copy of M without the entries covered by the pattern of LEFT * RIGHT."""
    covered = LEFT.mxm(RIGHT, op=graphblas.semiring.any_pair).new(dtype=BOOL)
    return M.dup(mask=~covered.S)


def decompose(M: Matrix):
    """
    Repeatedly peel row-based and column-based blocks off a sparse boolean matrix.

    Each round runs `row_based_decompose`, removes the covered entries from
    the working copy of M, then does the same with `column_based_decompose`.
    The factors of every round are collected, and the loop stops when

    * a round removed less than 5% of the entries it started with, or
    * the factors of a round hold more than 30% as many entries as the round
      removed (or the round removed nothing), or
    * no entries are left.

    The caller's matrix is not modified; only a local copy shrinks.

    Parameters:
        M (Matrix): Input sparse boolean matrix.

    Returns:
        LEFT (Matrix): All left factors placed side by side.
        RIGHT (Matrix): All right factors stacked on top of each other.

        For an input without entries, empty factors (zero inner dimension)
        with the dtype of M are returned. The remainder matrix is not returned.
    """
    if M.nvals == 0:
        return Matrix(M.dtype, M.nrows, 0), Matrix(M.dtype, 0, M.ncols)

    accumulated_LEFT = []
    accumulated_RIGHT = []

    while True:
        nvals_before = M.nvals

        LEFT1, RIGHT1 = row_based_decompose(M)
        if LEFT1.nvals != 0:
            M = _mask_out_product(M, LEFT1, RIGHT1)

        # The column pass works on what the row pass left behind.
        LEFT2, RIGHT2 = column_based_decompose(M)
        if LEFT2.nvals != 0:
            M = _mask_out_product(M, LEFT2, RIGHT2)

        # Factors of the final round are kept too: their entries have already
        # been removed from M above.
        accumulated_LEFT.extend([LEFT1, LEFT2])
        accumulated_RIGHT.extend([RIGHT1, RIGHT2])

        nvals_LEFT_RIGHT = LEFT1.nvals + RIGHT1.nvals + LEFT2.nvals + RIGHT2.nvals
        nvals_after = M.nvals
        delta_M = nvals_before - nvals_after

        reduction_ratio = delta_M / nvals_before if nvals_before > 0 else 0
        size_ratio = nvals_LEFT_RIGHT / delta_M if delta_M > 0 else float('inf')

        if reduction_ratio < 0.05 or size_ratio > 0.3:
            break

        if nvals_after == 0:
            break

    # The loop body ran at least once, so both lists are non-empty here.
    LEFT = stack([accumulated_LEFT])
    RIGHT = stack([[RIGHT] for RIGHT in accumulated_RIGHT])

    return LEFT, RIGHT
194
+
195
def stack(matrix_grid: list[list[Matrix]]) -> Matrix:
    """
    Assemble a 2D grid of matrices into one block matrix.

    Matrices inside one inner list are placed side by side (left to right);
    the resulting block rows are placed one below the other in list order.
    The dtype of the result is taken from the top-left matrix.

    Parameters:
        matrix_grid (list[list[Matrix]]): Grid of blocks, one inner list per block row.

    Returns:
        Matrix: The assembled block matrix.

    Raises:
        ValueError: If the grid is empty, the inner lists differ in length,
            blocks in one block row differ in height, or blocks in one block
            column differ in width.
    """
    if not matrix_grid or not matrix_grid[0]:
        raise ValueError("The matrix grid cannot be empty.")

    first_row = matrix_grid[0]
    blocks_per_row = len(first_row)

    if any(len(grid_row) != blocks_per_row for grid_row in matrix_grid):
        raise ValueError("All rows in the matrix grid must have the same number of matrices.")

    for grid_row in matrix_grid:
        expected_height = grid_row[0].nrows
        if any(block.nrows != expected_height for block in grid_row):
            raise ValueError("All matrices in the same row must have the same number of rows.")

    for position, reference in enumerate(first_row):
        expected_width = reference.ncols
        if any(grid_row[position].ncols != expected_width for grid_row in matrix_grid):
            raise ValueError("All matrices in the same column must have the same number of columns.")

    # Collect the coordinate triples of every block, shifted to the block's
    # position inside the assembled matrix.
    row_parts = []
    col_parts = []
    value_parts = []

    row_offset = 0
    for grid_row in matrix_grid:
        col_offset = 0
        for block in grid_row:
            block_rows, block_cols, block_values = block.to_coo()
            row_parts.append(block_rows + row_offset)
            col_parts.append(block_cols + col_offset)
            value_parts.append(block_values)
            col_offset += block.ncols
        row_offset += grid_row[0].nrows

    return Matrix.from_coo(
        rows=np.concatenate(row_parts),
        columns=np.concatenate(col_parts),
        values=np.concatenate(value_parts),
        dtype=first_row[0].dtype,
        nrows=row_offset,
        ncols=sum(block.ncols for block in first_row),
    )
0 commit comments