Skip to content

Commit f09902e

Browse files
committed
Attempt to introduce masking
1 parent debb9d3 commit f09902e

File tree

8 files changed

+70
-18
lines changed

8 files changed

+70
-18
lines changed

cfpq_matrix/abstract_optimized_matrix_decorator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,6 @@ def dtype(self) -> DataType:
3131

3232
def to_unoptimized(self) -> Matrix:
3333
return self.base.to_unoptimized()
34+
35+
def to_mask(self) -> Matrix:
36+
return self.base.to_mask()

cfpq_matrix/block/block_matrix.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,23 @@ def optimize_similarly(self, other: Matrix) -> "OptimizedMatrix":
2323
self.base.optimize_similarly(other)
2424
)
2525

26+
def to_mask(self):
27+
if self.block_matrix_space.is_single_cell(self.shape):
28+
print ("@@@@@@@")
29+
print (self.shape)
30+
return self.base.to_mask()
31+
else:
32+
return None
33+
2634

2735
class CellBlockMatrix(BlockMatrix):
2836
def __init__(self, base: OptimizedMatrix, block_matrix_space: BlockMatrixSpace):
2937
assert block_matrix_space.is_single_cell(base.shape)
3038
super().__init__(base, block_matrix_space)
3139

32-
def mxm(self, other: Matrix, op: Semiring, swap_operands: bool = False) -> Matrix:
40+
def mxm(self, other: Matrix, op: Semiring, mask:Matrix, swap_operands: bool = False) -> Matrix:
3341
if self.block_matrix_space.is_single_cell(other.shape):
34-
return self.base.mxm(other, op, swap_operands=swap_operands)
42+
return self.base.mxm(other, op, mask, swap_operands=swap_operands)
3543
return self.base.mxm(
3644
self.block_matrix_space.hyper_rotate(
3745
other,
@@ -40,6 +48,7 @@ def mxm(self, other: Matrix, op: Semiring, swap_operands: bool = False) -> Matri
4048
else BlockMatrixOrientation.HORIZONTAL
4149
),
4250
op=op,
51+
mask=mask,
4352
swap_operands=swap_operands,
4453
)
4554

@@ -83,20 +92,21 @@ def _force_init_orientation(
8392
self.discard_base_on_reformat = False
8493
return self.matrices[desired_orientation]
8594

86-
def mxm(self, other: Matrix, op: Semiring, swap_operands: bool = False) -> Matrix:
95+
def mxm(self, other: Matrix, op: Semiring, mask:Matrix, swap_operands: bool = False) -> Matrix:
8796
if self.block_matrix_space.is_single_cell(other.shape):
8897
return self._force_init_orientation(
8998
BlockMatrixOrientation.HORIZONTAL
9099
if swap_operands
91100
else BlockMatrixOrientation.VERTICAL
92-
).mxm(other, op, swap_operands=swap_operands)
101+
).mxm(other, op, mask, swap_operands=swap_operands)
93102
return self._force_init_orientation(
94103
BlockMatrixOrientation.VERTICAL
95104
if swap_operands
96105
else BlockMatrixOrientation.HORIZONTAL
97106
).mxm(
98107
self.block_matrix_space.to_block_diag_matrix(other),
99108
op=op,
109+
mask=mask,
100110
swap_operands=swap_operands
101111
)
102112

cfpq_matrix/empty_optimized_matrix.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ def __init__(self, base: OptimizedMatrix):
1414
def base(self) -> OptimizedMatrix:
1515
return self._base
1616

17-
def mxm(self, other: Matrix, op: Semiring, swap_operands: bool = False) -> Matrix:
17+
def mxm(self, other: Matrix, op: Semiring, mask:Matrix, swap_operands: bool = False) -> Matrix:
1818
if self.nvals == 0 or other.nvals == 0:
1919
if swap_operands:
2020
assert self.shape[0] == other.shape[1]
2121
return Matrix(self.dtype, self.shape[1], other.shape[0])
2222
assert self.shape[1] == other.shape[0]
2323
return Matrix(self.dtype, self.shape[0], other.shape[1])
24-
return self.base.mxm(other, op, swap_operands)
24+
return self.base.mxm(other, op, mask, swap_operands)
2525

2626
def rsub(self, other: Matrix, op: SubOp) -> Matrix:
2727
if self.nvals == 0 or other.nvals == 0:

cfpq_matrix/format_optimized_matrix.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ def __new__(
3030
def base(self) -> OptimizedMatrix:
3131
return self._base
3232

33+
def to_mask(self) -> Matrix:
34+
print ("#####")
35+
#print (self.matrices[0].to_mask())
36+
return list(self.matrices.values())[0].to_mask()
37+
3338
def _force_init_format(self, desired_format: str) -> OptimizedMatrix:
3439
if desired_format not in self.matrices:
3540
base_matrix = self.base.to_unoptimized().dup()
@@ -42,16 +47,16 @@ def _force_init_format(self, desired_format: str) -> OptimizedMatrix:
4247
res = self.matrices[desired_format]
4348
return res
4449

45-
def mxm(self, other: Matrix, op: Semiring, swap_operands: bool = False) -> Matrix:
50+
def mxm(self, other: Matrix, op: Semiring, mask: Matrix, swap_operands: bool = False) -> Matrix:
4651
left_nvals = other.nvals if swap_operands else self.nvals
4752
right_nvals = self.nvals if swap_operands else other.nvals
4853
desired_format = "by_row" if left_nvals < right_nvals else "by_col"
4954

5055
if desired_format in self.matrices or other.nvals < self.nvals / self.reformat_threshold:
5156
other.ss.config["format"] = desired_format
5257
reformatted_self = self._force_init_format(desired_format)
53-
return reformatted_self.mxm(other, op, swap_operands=swap_operands)
54-
return self.base.mxm(other, op, swap_operands=swap_operands)
58+
return reformatted_self.mxm(other, op, mask, swap_operands=swap_operands)
59+
return self.base.mxm(other, op, mask, swap_operands=swap_operands)
5560

5661
def rsub(self, other: Matrix, op: SubOp) -> Matrix:
5762
return self.matrices.get(other.ss.config["format"], self.base).rsub(other, op)

cfpq_matrix/lazy_add_optimized_matrix.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ def base(self) -> OptimizedMatrix:
2323
def nvals(self) -> int:
2424
return sum(m.nvals for m in self.matrices)
2525

26+
def to_mask(self) -> Matrix:
27+
return self.matrices[0].to_mask()
28+
2629
def _map_and_fold(
2730
self,
2831
mapper,
@@ -53,10 +56,10 @@ def to_unoptimized(self) -> Matrix:
5356
self.force_combine_small_matrices(nvals_combine_threshold=float("inf"))
5457
return self.base.to_unoptimized()
5558

56-
def mxm(self, other: Matrix, op: Semiring, swap_operands: bool = False) -> Matrix:
59+
def mxm(self, other: Matrix, op: Semiring, mask:Matrix, swap_operands: bool = False) -> Matrix:
5760
self.update_monoid(op.monoid)
5861
return self._map_and_fold(
59-
mapper=lambda m: m.mxm(other, op=op, swap_operands=swap_operands),
62+
mapper=lambda m: m.mxm(other, op=op, mask=mask, swap_operands=swap_operands),
6063
combiner=lambda acc, cur: acc.ewise_add(cur, op=op.monoid).new(),
6164
nvals_combine_threshold=other.nvals
6265
)

cfpq_matrix/matrix_to_optimized_adapter.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,39 @@ def format(self) -> MatrixFormat:
3030
def dtype(self) -> DataType:
3131
return self.base.dtype
3232

33+
def to_mask(self) -> Matrix:
34+
return self.base
35+
3336
def to_unoptimized(self) -> Matrix:
3437
return self.base
3538

36-
def mxm(self, other: Matrix, op: Semiring, swap_operands: bool = False) -> Matrix:
37-
return (
38-
other.mxm(self.base, op)
39-
if swap_operands
40-
else self.base.mxm(other, op)
41-
).new(self.dtype)
39+
def mxm(self, other: Matrix, op: Semiring, mask: Matrix, swap_operands: bool = False) -> Matrix:
40+
# return (
41+
# other.mxm(self.base, op)
42+
# if swap_operands
43+
# else self.base.mxm(other, op)
44+
#).new(self.dtype)
45+
if swap_operands:
46+
if not mask is None:
47+
#print("Mask applied, swap operands")
48+
#result = Matrix(self.dtype,nrows=other.shape[0],ncols=self.shape[1])
49+
#mask_t = Matrix(mask.dtype, ncols=mask.ncols, nrows=mask.nrows)
50+
#mask_t << mask.T
51+
#result(~mask) << other.mxm(self.base, op)
52+
#result(~mask) << other.mxm(self.base, op).new(self.dtype)
53+
#return result
54+
return other.mxm(self.base, op).new(self.dtype)
55+
else: return other.mxm(self.base, op).new(self.dtype)
56+
else:
57+
if not mask is None:
58+
print("Mask applied")
59+
result = Matrix(self.dtype,nrows=self.shape[0],ncols=other.shape[1])
60+
#result(~mask) << self.base.mxm(other, op)
61+
result(~mask) << self.base.mxm(other, op).new(self.dtype)
62+
return result
63+
#return self.base.mxm(other, op).new(self.dtype)
64+
else: return self.base.mxm(other, op).new(self.dtype)
65+
4266

4367
def rsub(self, other: Matrix, op: SubOp) -> Matrix:
4468
return op(other, self.base)

cfpq_matrix/optimized_matrix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def to_unoptimized(self) -> Matrix:
4747
pass
4848

4949
@abstractmethod
50-
def mxm(self, other: Matrix, op: Semiring, swap_operands: bool = False) -> Matrix:
50+
def mxm(self, other: Matrix, op: Semiring, mask: Matrix, swap_operands: bool = False) -> Matrix:
5151
pass
5252

5353
@abstractmethod

cfpq_model/label_decomposed_graph.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,9 +262,16 @@ def mxm(
262262
if swap_operands:
263263
rhs1, rhs2 = rhs2, rhs1
264264
if rhs1 in self.matrices and rhs2 in other.matrices:
265+
#if lhs in self.matrices:
266+
# print ("!!! Mask: ")
267+
# print(self.matrices[lhs])
268+
# print("-------------------")
269+
# print(self.matrices[rhs1].shape)
270+
# print(other.matrices[rhs2].shape)
265271
mxm = self.matrices[rhs1].mxm(
266272
other.matrices[rhs2],
267273
swap_operands=swap_operands,
274+
mask=(self.matrices[lhs].to_mask() if lhs in self.matrices and other.matrices[rhs2].shape == self.matrices[rhs1].shape else None), #Matrix(dtype=self.matrices[rhs1].dtype, nrows=self.matrices[rhs1].nrows, ncols = self.matrices[rhs2].ncols)),
268275
op=op,
269276
)
270277
accum.iadd_by_symbol(lhs, mxm, op.monoid)

0 commit comments

Comments
 (0)