@@ -1,7 +1,11 @@
-from typing import Any
+import random
+from collections import defaultdict
+from typing import Any, Tuple
 
 import graphblas
-from graphblas.core.dtypes import DataType
+import numpy as np
+from graphblas.binary import plus
+from graphblas.core.dtypes import DataType, BOOL, INT32
 from graphblas.core.matrix import Matrix
 from graphblas.core.vector import Vector
 
@@ -22,3 +26,238 @@ def identity_matrix(one: Any, dtype: DataType, size: int) -> Matrix:
         size=size,
         dtype=dtype
     ).diag()
+
+def expand_matrix(matrix: Matrix, new_shape: Tuple[int, int]) -> Matrix:
+    """Return a copy of matrix with the same entries embedded in the larger (nrows, ncols) given by new_shape."""
+    (rows, columns, values) = matrix.to_coo()
+    return Matrix.from_coo(rows, columns, values, dtype=matrix.dtype, nrows=new_shape[0], ncols=new_shape[1])
+
+def row_based_decompose(M: Matrix):
+    """
+    Decompose a sparse boolean matrix M into factors LEFT and RIGHT such that
+    LEFT * RIGHT covers a subset of the entries of M, i.e. M = LEFT * RIGHT + M'
+    for some remainder M' (the caller can recover M' by masking M with the
+    complement of the product's structure).
+
+    Parameters:
+        M (Matrix): Input sparse boolean matrix.
+
+    Returns:
+        LEFT (Matrix): Left factor matrix of shape (M.nrows, k).
+        RIGHT (Matrix): Right factor matrix of shape (k, M.ncols).
+    """
+    n_rows, n_cols = M.shape
+
+    I, J, V = M.to_coo()
+
+    rows = defaultdict(set)
+    for i, j in zip(I, J):
+        rows[i].add(j)
+
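+    # Build a MinHash signature of each row's column set with num_hashes random
+    # universal hash functions h(x) = (a*x + b) mod p, where p = 2^31 - 1 is a
+    # Mersenne prime. Rows with identical signatures likely have similar column
+    # sets; rows with fewer than 5 nonzeros are skipped below.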
+    p = 2147483647
+    num_hashes = 5  # TODO 2 or 3 is probably better for real world data
+    hash_funcs = []
+    for _ in range(num_hashes):
+        a = random.randint(1, p - 1)
+        b = random.randint(0, p - 1)
+        hash_funcs.append((a, b))
+
+    minhashes = dict()
+
+    for i, S_i in rows.items():
+        if len(S_i) < 5:
+            continue
+        minhash_values = []
+        for a, b in hash_funcs:
+            min_hash = min(((a * x + b) % p) for x in S_i)
+            minhash_values.append(min_hash)
+        minhashes[i] = tuple(minhash_values)
+
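+    # Collapse each signature into a single hash and group rows by it; only
+    # buckets with at least 5 candidate rows are kept.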
+    master_hashes = dict()
+    for i, minhash_values in minhashes.items():
+        master_hash = hash(minhash_values)
+        master_hashes[i] = master_hash
+
+    buckets = defaultdict(list)
+    for i, master_hash in master_hashes.items():
+        buckets[master_hash].append(i)
+
+    buckets = {h: idxs for h, idxs in buckets.items() if len(idxs) >= 5}
+
+    LEFT_columns = []
+    RIGHT_rows = []
+
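+    # For each bucket, look for a dense block: find the columns present in at
+    # least 95% of the bucket's rows, keep only the rows that contain all of
+    # those columns, and repeat the column step on that reduced row set.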
+    for h, B in buckets.items():
+        N = len(B)
+        M_B: Matrix = M[B, :].new()
+        A1 = M_B.dup(dtype=INT32).reduce_columnwise(plus).new()
+
+        threshold = int(0.95 * N)
+        A2: Vector = A1.select('>=', threshold).new()
+
+        if A2.nvals == 0:
+            continue
+
+        S_A2 = set(A2.to_coo()[0])
+
+        B_prime = [i for i in B if S_A2 <= rows[i]]
+
+        K = len(B_prime)
+        if K == 0:
+            continue
+
+        M_B_prime = M[B_prime, :].new()
+        A3 = M_B_prime.dup(dtype=INT32).reduce_columnwise(plus).new()
+
+        threshold = int(0.95 * K)
+        A4 = A3.select('>=', threshold).new()
+
+        if A4.nvals == 0:
+            continue
+
+        S_A4 = set(A4.to_coo()[0])
+
+        B_double_prime = [i for i in B_prime if S_A4 <= rows[i]]
+
+        if len(B_double_prime) < 5:
+            continue
+
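+        # The surviving column pattern A4 becomes a row of RIGHT, and an
+        # indicator vector over the surviving rows B'' becomes the matching
+        # column of LEFT.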
+        RIGHT_rows.append(A4)
+
+        CORE = Vector(BOOL, size=n_rows)
+        for i in B_double_prime:
+            CORE[i] = True
+        LEFT_columns.append(CORE)
+
+    num_buckets_remaining = len(LEFT_columns)
+    if num_buckets_remaining == 0:
+        return Matrix(M.dtype, M.nrows, 0), Matrix(M.dtype, 0, M.ncols)
+
+    LEFT = Matrix(bool, n_rows, num_buckets_remaining)
+    for idx, CORE in enumerate(LEFT_columns):
+        LEFT[:, idx] = CORE
+
+    RIGHT = Matrix(bool, num_buckets_remaining, n_cols)
+    for idx, A4 in enumerate(RIGHT_rows):
+        RIGHT[idx, :] = A4
+
+    return LEFT, RIGHT
+
+def column_based_decompose(M: Matrix):
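+    """Same as row_based_decompose, but applied to columns by factoring the transpose."""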
+    LEFT_T, RIGHT_T = row_based_decompose(M.T.new())
+    return RIGHT_T.T.new(), LEFT_T.T.new()
+
+def decompose(M: Matrix):
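+    """
+    Iteratively factor a sparse boolean matrix M into LEFT and RIGHT with
+    LEFT * RIGHT covering part of M. Each iteration runs a row-based and a
+    column-based pass, removes the covered entries from M, and accumulates the
+    factors; the accumulated pieces are stacked into the final LEFT and RIGHT.
+    """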
+    accumulated_LEFT = []
+    accumulated_RIGHT = []
+    iteration = 0
+
+    init_nvals = M.nvals
+    if init_nvals == 0:
+        return Matrix(M.dtype, M.nrows, 0), Matrix(M.dtype, 0, M.ncols)
+
+    while True:
+        iteration += 1
+        nvals_before = M.nvals
+
+        LEFT1, RIGHT1 = row_based_decompose(M)
+
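+        # Drop the entries of M already covered by LEFT1 * RIGHT1: the any_pair
+        # semiring yields the boolean product's structure, and its complemented
+        # structural mask keeps only the uncovered entries.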
+        if LEFT1.nvals != 0:
+            M = M.dup(mask=~LEFT1.mxm(RIGHT1, op=graphblas.semiring.any_pair).new(dtype=BOOL).S)
+
+        LEFT2, RIGHT2 = column_based_decompose(M)
+
+        if LEFT2.nvals != 0:
+            M = M.dup(mask=~LEFT2.mxm(RIGHT2, op=graphblas.semiring.any_pair).new(dtype=BOOL).S)
+
+        nvals_LEFT_RIGHT = LEFT1.nvals + RIGHT1.nvals + LEFT2.nvals + RIGHT2.nvals
+
+        nvals_after = M.nvals
+        delta_M = nvals_before - nvals_after
+
+        reduction_ratio = delta_M / nvals_before if nvals_before > 0 else 0
+        size_ratio = nvals_LEFT_RIGHT / delta_M if delta_M > 0 else float('inf')
+
+        accumulated_LEFT.extend([LEFT1, LEFT2])
+        accumulated_RIGHT.extend([RIGHT1, RIGHT2])
+
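+        # Stop when an iteration removes less than 5% of the remaining entries,
+        # or when the factors cost more than 0.3 stored values per entry removed.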
+        if reduction_ratio < 0.05 or size_ratio > 0.3:
+            break
+
+        if M.nvals == 0:
+            break
+
+    if not accumulated_LEFT or not accumulated_RIGHT:
+        return Matrix(BOOL, nrows=M.nrows, ncols=0), Matrix(BOOL, nrows=0, ncols=M.ncols)
+
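+    # Concatenate the per-iteration LEFT factors side by side (one grid row) and
+    # the RIGHT factors top to bottom (one grid column) so that LEFT * RIGHT is
+    # the sum of the individual products.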
+    LEFT = stack([accumulated_LEFT])
+    RIGHT = stack([[R] for R in accumulated_RIGHT])
+
+    return LEFT, RIGHT
+
+def stack(matrix_grid: list[list[Matrix]]) -> Matrix:
+    """
+    Stack a 2D list of matrices into a single larger matrix.
+    Horizontally stacks the matrices within each row of the list, then vertically stacks the resulting rows.
+
+    Parameters:
+        matrix_grid (list[list[Matrix]]): A 2D list of matrices to stack.
+
+    Returns:
+        Matrix: The stacked matrix.
+    """
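+    # For example, a grid [[A, B], [C, D]] with A 2x3, B 2x5, C 4x3, and D 4x5
+    # produces a 6x8 matrix with A in the top-left block and D in the bottom-right.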
+    if not matrix_grid or not matrix_grid[0]:
+        raise ValueError("The matrix grid cannot be empty.")
+
+    num_cols = len(matrix_grid[0])
+    for row in matrix_grid:
+        if len(row) != num_cols:
+            raise ValueError("All rows in the matrix grid must have the same number of matrices.")
+
+    for row in matrix_grid:
+        row_height = row[0].nrows
+        for matrix in row:
+            if matrix.nrows != row_height:
+                raise ValueError("All matrices in the same row must have the same number of rows.")
+
+    for col in range(num_cols):
+        col_width = matrix_grid[0][col].ncols
+        for row in matrix_grid:
+            if row[col].ncols != col_width:
+                raise ValueError("All matrices in the same column must have the same number of columns.")
+
+    combined_rows = []
+    combined_columns = []
+    combined_values = []
+
+    current_row_offset = 0
+
+    for row in matrix_grid:
+        current_col_offset = 0
+
+        for matrix in row:
+            M_I, M_J, M_V = matrix.to_coo()
+
+            adjusted_rows = M_I + current_row_offset
+            adjusted_columns = M_J + current_col_offset
+
+            combined_rows.append(adjusted_rows)
+            combined_columns.append(adjusted_columns)
+            combined_values.append(M_V)
+
+            current_col_offset += matrix.ncols
+
+        current_row_offset += row[0].nrows
+
+    final_rows = np.concatenate(combined_rows)
+    final_columns = np.concatenate(combined_columns)
+    final_values = np.concatenate(combined_values)
+
+    total_rows = current_row_offset
+    total_columns = sum(matrix.ncols for matrix in matrix_grid[0])
+
+    return Matrix.from_coo(
+        rows=final_rows,
+        columns=final_columns,
+        values=final_values,
+        dtype=matrix_grid[0][0].dtype,
+        nrows=total_rows,
+        ncols=total_columns,
+    )