Skip to content

Commit b001e2e

Browse files
committed
Don't simplify perfect element-wise products
1 parent 70a28d3 commit b001e2e

2 files changed

Lines changed: 10 additions & 3 deletions

File tree

src/forward.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,11 @@ function is_elementwise_multiplication(arg1, arg2)
125125
return !isempty(indices_in_common(arg1, arg2))
126126
end
127127

128+
function is_all_elementwise(arg1, arg2)
129+
num_common_ids = length(indices_in_common(arg1, arg2))
130+
return num_common_ids == length(get_free_indices(arg1)) &&
131+
num_common_ids == length(get_free_indices(arg2))
132+
end
128133

129134
function is_diagm(arg::BinaryOperation{Mult})
130135
return is_diagm(arg.arg1, arg.arg2)
@@ -182,11 +187,11 @@ function evaluate(::Mult, arg1::BinaryOperation{Mult}, arg2::Union{Variable,Lite
182187
end
183188

184189
is_arg1_elementwise = is_elementwise_multiplication(arg1.arg1, arg1.arg2)
185-
is_all_elementwise =
190+
are_both_elwise =
186191
is_elementwise_multiplication(arg1.arg1, arg2) &&
187192
is_elementwise_multiplication(arg1.arg2, arg2)
188193

189-
if is_arg1_elementwise || is_all_elementwise
194+
if is_arg1_elementwise || are_both_elwise
190195
return BinaryOperation{Mult}(evaluate(arg1), arg2)
191196
end
192197

src/simplify.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,9 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Literal)
9494
return simplify(reshaped)
9595
end
9696

97-
if can_contract(arg1, arg2) && length(arg2_free_ids) == 1
97+
if can_contract(arg1, arg2) &&
98+
length(arg2_free_ids) == 1 &&
99+
!is_all_elementwise(arg1.arg1, arg1.arg2)
98100
elwise_ids = elementwise_indices(arg1.arg1, arg1.arg2)
99101

100102
target_idx = only(arg2_free_ids)

0 commit comments

Comments
 (0)