@@ -23,15 +23,23 @@ def optimize_similarly(self, other: Matrix) -> "OptimizedMatrix":
2323 self .base .optimize_similarly (other )
2424 )
2525
26+ def to_mask (self ):
27+ if self .block_matrix_space .is_single_cell (self .shape ):
28+ print ("@@@@@@@" )
29+ print (self .shape )
30+ return self .base .to_mask ()
31+ else :
32+ return None
33+
2634
2735class CellBlockMatrix (BlockMatrix ):
2836 def __init__ (self , base : OptimizedMatrix , block_matrix_space : BlockMatrixSpace ):
2937 assert block_matrix_space .is_single_cell (base .shape )
3038 super ().__init__ (base , block_matrix_space )
3139
32- def mxm (self , other : Matrix , op : Semiring , swap_operands : bool = False ) -> Matrix :
40+ def mxm (self , other : Matrix , op : Semiring , mask : Matrix , swap_operands : bool = False ) -> Matrix :
3341 if self .block_matrix_space .is_single_cell (other .shape ):
34- return self .base .mxm (other , op , swap_operands = swap_operands )
42+ return self .base .mxm (other , op , mask , swap_operands = swap_operands )
3543 return self .base .mxm (
3644 self .block_matrix_space .hyper_rotate (
3745 other ,
@@ -40,6 +48,7 @@ def mxm(self, other: Matrix, op: Semiring, swap_operands: bool = False) -> Matri
4048 else BlockMatrixOrientation .HORIZONTAL
4149 ),
4250 op = op ,
51+ mask = mask ,
4352 swap_operands = swap_operands ,
4453 )
4554
@@ -83,20 +92,21 @@ def _force_init_orientation(
8392 self .discard_base_on_reformat = False
8493 return self .matrices [desired_orientation ]
8594
86- def mxm (self , other : Matrix , op : Semiring , swap_operands : bool = False ) -> Matrix :
95+ def mxm (self , other : Matrix , op : Semiring , mask : Matrix , swap_operands : bool = False ) -> Matrix :
8796 if self .block_matrix_space .is_single_cell (other .shape ):
8897 return self ._force_init_orientation (
8998 BlockMatrixOrientation .HORIZONTAL
9099 if swap_operands
91100 else BlockMatrixOrientation .VERTICAL
92- ).mxm (other , op , swap_operands = swap_operands )
101+ ).mxm (other , op , mask , swap_operands = swap_operands )
93102 return self ._force_init_orientation (
94103 BlockMatrixOrientation .VERTICAL
95104 if swap_operands
96105 else BlockMatrixOrientation .HORIZONTAL
97106 ).mxm (
98107 self .block_matrix_space .to_block_diag_matrix (other ),
99108 op = op ,
109+ mask = mask ,
100110 swap_operands = swap_operands
101111 )
102112
0 commit comments