Skip to content

Commit 8ea52f6

Browse files
committed
Replace is_diag with is_elementwise_multiplication
1 parent d58ebab commit 8ea52f6

1 file changed

Lines changed: 2 additions & 24 deletions

File tree

src/forward.jl

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ function evaluate(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
147147
contracting_indices = eliminated_indices([arg1_indices; arg2_indices])
148148

149149
if is_elementwise &&
150-
is_diag(arg1) &&
150+
is_diag2(arg1) &&
151151
!isempty(contracting_indices) &&
152152
length(arg2_indices) == 1
153153
new_index = setdiff(arg1_indices, contracting_indices)
@@ -337,28 +337,6 @@ function evaluate(::Mult, arg1::KrD, arg2::UnaryOp) where {UnaryOp<:UnaryOperati
337337
return BinaryOperation{Mult}(evaluate(arg2), evaluate(arg1))
338338
end
339339

340-
function is_diag(arg1::KrD, arg2::TensorExpr)
341-
return is_diag(arg2, arg1)
342-
end
343-
344-
function is_diag(arg1::TensorExpr, arg2::KrD)
345-
arg1_indices, arg2_indices = get_free_indices.((arg1, arg2))
346-
347-
return length(arg1_indices) == 1 && !isempty(intersect(arg1_indices, arg2_indices))
348-
end
349-
350-
function is_diag(arg1::KrD, arg2::KrD)
351-
return false
352-
end
353-
354-
function is_diag(arg::BinaryOperation{Mult})
355-
return is_diag(arg.arg1, arg.arg2)
356-
end
357-
358-
function is_diag(arg1, arg2)
359-
return false
360-
end
361-
362340
function evaluate(::Mult, arg1::Tensor, arg2::KrD)
363341
return _multiply_with_krd(arg1, arg2)
364342
end
@@ -379,7 +357,7 @@ function _multiply_with_krd(arg1::Union{Tensor,KrD}, arg2::KrD)
379357
return BinaryOperation{Mult}(arg1, arg2)
380358
end
381359

382-
if is_diag(arg1, arg2)
360+
if is_elementwise_multiplication(arg1, arg2)
383361
return BinaryOperation{Mult}(arg1, arg2)
384362
end
385363

0 commit comments

Comments
 (0)