Skip to content

Einsum Reduction #745

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions benchmark_cfc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
using BenchmarkTools;
using Finch;


## Imagine: Three {0-3}-D matrices
## Operations: + and *? Maybe more? (Ex: min)

function matmul(B, C)
return B*C
end

function matmul_einsum(B, C)
return @einsum A[i, j] += B[i, k] * C[k, j]
end

function matadd(B, C)
return B + C
end

function matadd_einsum(B, C)
return @einsum A[i, j] += B[i, j] + C[i, j]
end

function matvectmul(B, x)
return B * x
end

function matvectmul_einsum(B, x)
return @einsum y[i] += B[i, k] * x[k]
end

function tensormatmul(B, C)
return B * C
end

# relative to this guy

function tensormatmul_einsum(B, C)
return @einsum A[i, j, k] += B[i, j, l] * C[l, k]
end

# 4 kernels
# compare this guy

function tensormatmul_einsum_reshape(B, C)
B_p = reshape(B, (size(B)[1]*size(B)[2], size(B)[3]))
@einsum A_p[ij, k] += B_p[ij, l] * C[l, k]
A = reshape(A_p, (size(B)[1], size(B)[2], size(C)[2]))
return A
end

sparsity = 0.01
row = 1_000
col = 1_000
tube = 100
# TODO: different size + sparsity!
A_tensor = fsprand(row, col, tube, sparsity)
B = fsprand(row, tube, col, sparsity)
C = fsprand(col, row, sparsity)
x = fsprand(row, sparsity)

display(sparsity)
display(@benchmark tensormatmul_einsum(B, C))
display(@benchmark tensormatmul_einsum_reshape(B, C))
# display(@benchmark matmul(B, C))
# display(@benchmark matmul_einsum(B, C))
# display(@benchmark matadd(B, C))
# display(@benchmark matadd_einsum(B, C))
# display(@benchmark matvectmul(B, x))
# display(@benchmark matvectmul_einsum(B, x))
# display(@benchmark tensormatmul(A_tensor, C))
# display(@benchmark tensormatmul_einsum(A_tensor, C))

# sparsity = 0.5
# # TODO: different size + sparsity!
# A_tensor = fsprand(6, 6, 6, sparsity)
# B = fsprand(6, 6, sparsity)
# C = fsprand(6, 6, sparsity)
# x = fsprand(6, sparsity)
#
# display(sparsity)
# display(@benchmark matmul(B, C))
# display(@benchmark matmul_einsum(B, C))
# display(@benchmark matadd(B, C))
# display(@benchmark matadd_einsum(B, C))
# display(@benchmark matvectmul(B, x))
# display(@benchmark matvectmul_einsum(B, x))
# display(@benchmark tensormatmul(A_tensor, C))
# display(@benchmark tensormatmul_einsum(A_tensor, C))
#
# sparsity = 0.75
# # TODO: different size + sparsity!
# A_tensor = fsprand(6, 6, 6, sparsity)
# B = fsprand(6, 6, sparsity)
# C = fsprand(6, 6, sparsity)
# x = fsprand(6, sparsity)
#
# display(sparsity)
# display(@benchmark matmul(B, C))
# display(@benchmark matmul_einsum(B, C))
# display(@benchmark matadd(B, C))
# display(@benchmark matadd_einsum(B, C))
# display(@benchmark matvectmul(B, x))
# display(@benchmark matvectmul_einsum(B, x))
# display(@benchmark tensormatmul(A_tensor, C))
# display(@benchmark tensormatmul_einsum(A_tensor, C))
#
# sparsity = 1.0
# # TODO: different size + sparsity!
# A_tensor = fsprand(6, 6, 6, sparsity)
# B = fsprand(6, 6, sparsity)
# C = fsprand(6, 6, sparsity)
# x = fsprand(6, sparsity)
#
# display(sparsity)
# display(@benchmark matmul(B, C))
# display(@benchmark matmul_einsum(B, C))
# display(@benchmark matadd(B, C))
# display(@benchmark matadd_einsum(B, C))
# display(@benchmark matvectmul(B, x))
# display(@benchmark matvectmul_einsum(B, x))
# display(@benchmark tensormatmul(A_tensor, C))
# display(@benchmark tensormatmul_einsum(A_tensor, C))
#
272 changes: 272 additions & 0 deletions src/interface/simple_tensor.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@

# TODO: possibly have it inherit ( <: ) from something
struct SimpleTensor{T}
shape::Array{Int64}
data::T
dim::Array{Array{Int64}}
end

function combine_two_dim(A_old, B_old, shared_dim)

Check warning on line 9 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L9

Added line #L9 was not covered by tests

A_shape = A_old.shape
B_shape = B_old.shape

Check warning on line 12 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L11-L12

Added lines #L11 - L12 were not covered by tests

A_dim_idx = Set() # Array()
B_dim_idx = Set() # Array()

Check warning on line 15 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L14-L15

Added lines #L14 - L15 were not covered by tests

A_new_dims = Array(collect(A_shape))
B_new_dims = Array(collect(B_shape))

Check warning on line 18 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L17-L18

Added lines #L17 - L18 were not covered by tests

combined_dim = 1
offset = 0
for dim in shared_dim
push!(A_dim_idx, dim[1])
push!(B_dim_idx, dim[2])
@assert A_shape[dim[1]] == B_shape[dim[2]]

Check warning on line 25 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L20-L25

Added lines #L20 - L25 were not covered by tests

deleteat!(A_new_dims, dim[1] - offset)
deleteat!(B_new_dims, dim[2] - offset)
offset += 1

Check warning on line 29 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L27-L29

Added lines #L27 - L29 were not covered by tests

# new size
combined_dim *= A_shape[dim[1]]
end

Check warning on line 33 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L32-L33

Added lines #L32 - L33 were not covered by tests

push!(A_new_dims, combined_dim)
push!(B_new_dims, combined_dim)

Check warning on line 36 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L35-L36

Added lines #L35 - L36 were not covered by tests

_A_new_dims = Array(collect(setdiff(Set(1:length(A_shape)), Set(A_dim_idx))))
A_new_dims = [[dim] for dim in _A_new_dims]
push!(A_new_dims, Array(collect(A_dim_idx)))

Check warning on line 40 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L38-L40

Added lines #L38 - L40 were not covered by tests

_B_new_dims = Array(collect(setdiff(Set(1:length(B_shape)), Set(B_dim_idx))))
B_new_dims = [[dim] for dim in _B_new_dims]
push!(B_new_dims, Array(collect(B_dim_idx)))

Check warning on line 44 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L42-L44

Added lines #L42 - L44 were not covered by tests

A = reshape(A_old.data, combined_dim)
B = reshape(B_old.data, combined_dim)
A_new = SimpleTensor(A_shape, A, Array{Array{Int}}(collect(A_new_dims)))
B_new = SimpleTensor(B_shape, B, Array{Array{Int}}(B_new_dims))

Check warning on line 49 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L46-L49

Added lines #L46 - L49 were not covered by tests

return A_new, B_new

Check warning on line 51 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L51

Added line #L51 was not covered by tests
end

function combine_just_d(D_out, solution, unique_D)

Check warning on line 54 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L54

Added line #L54 was not covered by tests
# basically a broadcast operation

end

function do_op_simple_tensor(A, B, C, D)

Check warning on line 59 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L59

Added line #L59 was not covered by tests
# D = A B C
# FOR NOW: assume A is a mask

A_size = size(A)
A_n_idx = length(A_size)
B_size = size(B)
B_n_idx = length(B_size)
C_size = size(C)
C_n_idx = length(C_size)
D_size = size(D)
D_n_idx = length(D_size)

Check warning on line 70 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L63-L70

Added lines #L63 - L70 were not covered by tests

# in all
ABCD_idx = Set()

Check warning on line 73 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L73

Added line #L73 was not covered by tests

# in three
ABC_idx = Set()
ABD_idx = Set()
ACD_idx = Set()
BCD_idx = Set()

Check warning on line 79 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L76-L79

Added lines #L76 - L79 were not covered by tests

# in two
AB_idx = Set()
AC_idx = Set()
AD_idx = Set()
BC_idx = Set()
BD_idx = Set()
CD_idx = Set()

Check warning on line 87 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L82-L87

Added lines #L82 - L87 were not covered by tests

# single index
A_idx = Set()
B_idx = Set()
C_idx = Set()
D_idx = Set()

Check warning on line 93 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L90-L93

Added lines #L90 - L93 were not covered by tests

# used indices
A_used_idx = Set()
B_used_idx = Set()
C_used_idx = Set()
D_used_idx = Set()

Check warning on line 99 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L96-L99

Added lines #L96 - L99 were not covered by tests

# a TON of loops to populate these sets

# in four
for a_i in 1:A_n_idx
for b_i in 1:B_n_idx
for c_i in 1:C_n_idx
for d_i in 1:D_n_idx
if !(a_i in A_used_idx || b_i in B_used_idx || c_i in C_used_idx || d_i in D_used_idx) && (A_size[a_i] == B_size[b_i] && B_size[b_i] == C_size[c_i] && C_size[c_i] == D_size[d_i])
push!(ABCD_idx, (a_i, b_i, c_i, d_i))
push!(A_used_idx, a_i)
push!(B_used_idx, b_i)
push!(C_used_idx, c_i)
push!(D_used_idx, d_i)

Check warning on line 113 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L104-L113

Added lines #L104 - L113 were not covered by tests
end
end
end
end
end

Check warning on line 118 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L115-L118

Added lines #L115 - L118 were not covered by tests

# in three
for a_i in 1:A_n_idx
for b_i in 1:B_n_idx
for c_i in 1:C_n_idx
if !(a_i in A_used_idx || b_i in B_used_idx || c_i in C_used_idx) && (A_size[a_i] == B_size[b_i] && B_size[b_i] == C_size[c_i])
push!(ABC_idx, (a_i, b_i, c_i))
push!(A_used_idx, a_i)
push!(B_used_idx, b_i)
push!(C_used_idx, c_i)

Check warning on line 128 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L121-L128

Added lines #L121 - L128 were not covered by tests
end
end
end
end

Check warning on line 132 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L130-L132

Added lines #L130 - L132 were not covered by tests

for b_i in 1:B_n_idx
for c_i in 1:C_n_idx
for d_i in 1:D_n_idx
if !(b_i in B_used_idx || c_i in C_used_idx || d_i in D_used_idx) && (B_size[b_i] == C_size[c_i] && C_size[c_i] == D_size[d_i])
push!(BCD_idx, (b_i, c_i, d_i))
push!(B_used_idx, b_i)
push!(C_used_idx, c_i)
push!(D_used_idx, d_i)

Check warning on line 141 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L134-L141

Added lines #L134 - L141 were not covered by tests
end
end
end
end

Check warning on line 145 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L143-L145

Added lines #L143 - L145 were not covered by tests


for a_i in 1:A_n_idx
for b_i in 1:B_n_idx
for d_i in 1:D_n_idx
if !(a_i in A_used_idx || b_i in B_used_idx || d_i in D_used_idx) && (A_size[a_i] == B_size[b_i] && B_size[b_i] == D_size[d_i])
push!(ABD_idx, (a_i, b_i, d_i))
push!(A_used_idx, a_i)
push!(B_used_idx, b_i)
push!(D_used_idx, d_i)

Check warning on line 155 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L148-L155

Added lines #L148 - L155 were not covered by tests
end
end
end
end

Check warning on line 159 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L157-L159

Added lines #L157 - L159 were not covered by tests

for a_i in 1:A_n_idx
for c_i in 1:C_n_idx
for d_i in 1:D_n_idx
if !(a_i in A_used_idx || c_i in C_used_idx || d_i in D_used_idx) && (A_size[a_i] == C_size[c_i] && C_size[c_i] == D_size[d_i])
push!(ACD_idx, (a_i, c_i, d_i))
push!(A_used_idx, a_i)
push!(C_used_idx, c_i)
push!(D_used_idx, d_i)

Check warning on line 168 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L161-L168

Added lines #L161 - L168 were not covered by tests
end
end
end
end

Check warning on line 172 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L170-L172

Added lines #L170 - L172 were not covered by tests

# in two
for a_i in 1:A_n_idx
for b_i in 1:B_n_idx
if !(a_i in A_used_idx || b_i in B_used_idx) && (A_size[a_i] == B_size[b_i])
push!(AB_idx, (a_i, b_i))
push!(A_used_idx, a_i)
push!(B_used_idx, b_i)

Check warning on line 180 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L175-L180

Added lines #L175 - L180 were not covered by tests
end
end
end

Check warning on line 183 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L182-L183

Added lines #L182 - L183 were not covered by tests

for a_i in 1:A_n_idx
for c_i in 1:C_n_idx
if !(a_i in A_used_idx || c_i in C_used_idx) && (A_size[a_i] == C_size[c_i])
push!(AC_idx, (a_i, c_i))
push!(A_used_idx, a_i)
push!(C_used_idx, c_i)

Check warning on line 190 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L185-L190

Added lines #L185 - L190 were not covered by tests
end
end
end

Check warning on line 193 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L192-L193

Added lines #L192 - L193 were not covered by tests

for a_i in 1:A_n_idx
for d_i in 1:D_n_idx
if !(a_i in A_used_idx || d_i in D_used_idx) && (A_size[a_i] == D_size[d_i])
push!(AD_idx, (a_i, d_i))
push!(A_used_idx, a_i)
push!(D_used_idx, d_i)

Check warning on line 200 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L195-L200

Added lines #L195 - L200 were not covered by tests
end
end
end

Check warning on line 203 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L202-L203

Added lines #L202 - L203 were not covered by tests

for b_i in 1:B_n_idx
for c_i in 1:C_n_idx
if !(b_i in B_used_idx || c_i in C_used_idx) && (B_size[b_i] == C_size[c_i])
push!(BC_idx, (b_i, c_i))
push!(B_used_idx, b_i)
push!(C_used_idx, c_i)

Check warning on line 210 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L205-L210

Added lines #L205 - L210 were not covered by tests
end
end
end

Check warning on line 213 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L212-L213

Added lines #L212 - L213 were not covered by tests

for b_i in 1:B_n_idx
for d_i in 1:D_n_idx
if !(b_i in B_used_idx || d_i in D_used_idx) && (B_size[b_i] == D_size[d_i])
push!(BD_idx, (b_i, d_i))
push!(B_used_idx, b_i)
push!(D_used_idx, d_i)

Check warning on line 220 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L215-L220

Added lines #L215 - L220 were not covered by tests
end
end
end

Check warning on line 223 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L222-L223

Added lines #L222 - L223 were not covered by tests

for c_i in 1:C_n_idx
for d_i in 1:D_n_idx
if !(c_i in C_used_idx || d_i in D_used_idx) && (C_size[c_i] == D_size[d_i])
push!(CD_idx, (c_i, d_i))
push!(C_used_idx, c_i)
push!(D_used_idx, d_i)

Check warning on line 230 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L225-L230

Added lines #L225 - L230 were not covered by tests
end
end
end

Check warning on line 233 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L232-L233

Added lines #L232 - L233 were not covered by tests

# single index
for a_i in 1:A_n_idx
if !(a_i in A_used_idx)
push!(A_idx, a_i)
push!(A_used_idx, a_i)

Check warning on line 239 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L236-L239

Added lines #L236 - L239 were not covered by tests
end
end

Check warning on line 241 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L241

Added line #L241 was not covered by tests

for b_i in 1:B_n_idx
if !(b_i in B_used_idx)
push!(B_idx, b_i)
push!(B_used_idx, b_i)

Check warning on line 246 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L243-L246

Added lines #L243 - L246 were not covered by tests
end
end

Check warning on line 248 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L248

Added line #L248 was not covered by tests

for c_i in 1:C_n_idx
if !(c_i in C_used_idx)
push!(C_idx, c_i)
push!(C_used_idx, c_i)

Check warning on line 253 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L250-L253

Added lines #L250 - L253 were not covered by tests
end
end

Check warning on line 255 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L255

Added line #L255 was not covered by tests

for d_i in 1:D_n_idx
if !(d_i in D_used_idx)
push!(D_idx, d_i)
push!(D_used_idx, d_i)

Check warning on line 260 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L257-L260

Added lines #L257 - L260 were not covered by tests
end
end

Check warning on line 262 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L262

Added line #L262 was not covered by tests

#### TIME FOR THE ACTUAL COMPUTATION!

A_SimpleTensor = SimpleTensor(Array(collect(A_size)), A, Array{Array{Int64}, 1}([ [dim] for dim in 1:A_n_idx ]))
B_SimpleTensor = SimpleTensor(Array(collect(B_size)), B, Array{Array{Int64}, 1}([ [dim] for dim in 1:B_n_idx ]))
C_SimpleTensor = SimpleTensor(Array(collect(C_size)), C, Array{Array{Int64}, 1}([ [dim] for dim in 1:C_n_idx ]))
D_SimpleTensor = SimpleTensor(Array(collect(D_size)), D, Array{Array{Int64}, 1}([ [dim] for dim in 1:D_n_idx ]))

Check warning on line 269 in src/interface/simple_tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/interface/simple_tensor.jl#L266-L269

Added lines #L266 - L269 were not covered by tests


end
Loading