@@ -136,6 +136,43 @@ function is_elementwise_multiplication(arg1, arg2)
136136 return ! isempty (indices_in_common (arg1, arg2))
137137end
138138
139+
140+ function is_diag (arg:: BinaryOperation{Mult} )
141+ return is_diag (arg. arg1, arg. arg2)
142+ end
143+
144+ function is_diag (arg)
145+ return false
146+ end
147+
148+ function is_diag (arg1:: KrD , arg2:: TensorExpr )
149+ return is_diag (arg2, arg1)
150+ end
151+
152+ function is_diag (arg1:: KrD , arg2:: KrD )
153+ return false
154+ end
155+
156+ function is_diag (arg:: Union{Tensor,KrD,Zero} )
157+ return false
158+ end
159+
160+ function is_diag (arg1:: TensorExpr , arg2:: KrD )
161+ arg1_indices, arg2_indices = get_free_indices .((arg1, arg2))
162+
163+ return length (arg1_indices) == 1 && ! isempty (intersect (arg1_indices, arg2_indices))
164+ end
165+
166+ function is_diag (arg1:: Value , arg2:: Value )
167+ if isempty (get_free_indices (arg1))
168+ return is_diag (arg2)
169+ elseif isempty (get_free_indices (arg2))
170+ return is_diag (arg1)
171+ end
172+
173+ return is_diag (arg1) || is_diag (arg2)
174+ end
175+
139176function evaluate (:: Mult , arg1:: Tensor , arg2:: BinaryOperation{Mult} )
140177 return evaluate (Mult (), arg2, arg1)
141178end
@@ -147,7 +184,7 @@ function evaluate(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
147184 contracting_indices = eliminated_indices ([arg1_indices; arg2_indices])
148185
149186 if is_elementwise &&
150- is_diag2 (arg1) &&
187+ is_diag (arg1) &&
151188 ! isempty (contracting_indices) &&
152189 length (arg2_indices) == 1
153190 new_index = setdiff (arg1_indices, contracting_indices)
0 commit comments