@@ -125,8 +125,45 @@ function simplify(::Mult, arg1::Tensor, arg2::BinaryOperation{Mult})
125125 return simplify (Mult (), arg2, arg1)
126126end
127127
128+ function can_apply (l:: BinaryOperation{Mult} , r:: KrD )
129+ return can_contract (l. arg1, r) || can_contract (l. arg2, r)
130+ end
131+
132+ function can_apply (l:: KrD , r:: BinaryOperation{Mult} )
133+ return can_contract (l, r. arg1) || can_contract (l, r. arg2)
134+ end
135+
136+ function can_apply (l:: BinaryOperation{Mult} , r:: BinaryOperation{Mult} )
137+ return can_apply (l, r. arg1) || can_apply (l, r. arg2)
138+ end
139+
140+ function can_apply (l:: BinaryOperation{Mult} , r:: Tensor )
141+ return can_contract (l. arg1, r) || can_contract (l. arg2, r)
142+ end
143+
144+ function can_apply (l:: Tensor , r:: BinaryOperation{Mult} )
145+ return can_contract (l, r. arg1) || can_contract (l, r. arg2)
146+ end
147+
148+ function can_apply (l:: Tensor , r:: KrD )
149+ return can_contract (l, r)
150+ end
151+
152+ function can_apply (l:: KrD , r:: Tensor )
153+ return can_contract (l, r)
154+ end
155+
156+ function can_apply (l:: KrD , r:: KrD )
157+ return can_contract (l, r)
158+ end
159+
160+ function can_apply (l, r)
161+ return false
162+ end
163+
128164function simplify (:: Mult , arg1:: BinaryOperation{Mult} , arg2:: Tensor )
129165 op = BinaryOperation {Mult} (arg1, arg2)
166+ target_indices = unique (get_free_indices (op))
130167
131168 if is_diagm (arg1) &&
132169 ! is_elementwise_multiplication (arg1, arg2) &&
@@ -135,7 +172,6 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
135172
136173 @assert ! isnothing (d)
137174
138- target_indices = eliminate_indices (vcat (get_free_indices (arg1), get_indices (arg2)))
139175 factors = collect_factors (arg1)
140176 vector_factors = filter (f -> f != d, factors)
141177 reshaped = []
@@ -189,6 +225,32 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
189225 return to_binary_operation (Mult (), reshaped)
190226 end
191227
228+ if length (get_free_indices (arg1)) > 2 && length (get_free_indices (arg2)) == 1
229+ if can_apply (arg1. arg1, arg2) &&
230+ can_apply (arg1. arg2, arg2) &&
231+ arg1. arg1 isa KrD &&
232+ arg1. arg2 isa KrD
233+ r_free_indices = get_free_indices (arg2)
234+
235+ new_r = update_index (
236+ arg2,
237+ only (r_free_indices),
238+ first (target_indices);
239+ allow_shape_change = true ,
240+ )
241+
242+ eliminated = eliminated_indices ([get_free_indices (arg1); r_free_indices])
243+ new_l = update_index (
244+ arg1. arg2,
245+ first (eliminated),
246+ first (target_indices);
247+ allow_shape_change = true ,
248+ )
249+
250+ return BinaryOperation {Mult} (new_l, new_r)
251+ end
252+ end
253+
192254 return op
193255end
194256
0 commit comments