@@ -161,6 +161,53 @@ function can_apply(l, r)
161161 return false
162162end
163163
164+ function sift_down (arg:: Tensor , tree:: BinaryOperation{Mult} )
165+ arg_ids = get_free_indices (arg)
166+
167+ l_ids = get_free_indices (tree. arg1)
168+ r_ids = get_free_indices (tree. arg2)
169+
170+ # In this case, 'arg' and 'tree' have the same modes, then we can do element-wise multiplication
171+ if isempty (setdiff (arg_ids, l_ids)) && isempty (setdiff (l_ids, arg_ids))
172+ new_tree = BinaryOperation {Mult} (BinaryOperation {Mult} (arg, tree. arg1), tree. arg2)
173+ @assert length (get_free_indices (new_tree)) ==
174+ length (get_free_indices (BinaryOperation {Mult} (tree, arg)))
175+
176+ return new_tree
177+ elseif isempty (setdiff (arg_ids, r_ids)) && isempty (setdiff (r_ids, arg_ids))
178+ new_tree = BinaryOperation {Mult} (tree. arg1, BinaryOperation {Mult} (tree. arg2, arg))
179+ @assert length (get_free_indices (new_tree)) ==
180+ length (get_free_indices (BinaryOperation {Mult} (tree, arg)))
181+
182+ return new_tree
183+ end
184+
185+ # Otherwise, recurse
186+ if isempty (setdiff (arg_ids, l_ids))
187+ new_tree = BinaryOperation {Mult} (sift_down (arg, tree. arg1), tree. arg2)
188+ @assert length (get_free_indices (new_tree)) ==
189+ length (get_free_indices (BinaryOperation {Mult} (tree, arg)))
190+
191+ return new_tree
192+ elseif isempty (setdiff (arg_ids, r_ids))
193+ new_tree = BinaryOperation {Mult} (tree. arg1, sift_down (arg, tree. arg2))
194+ @assert length (get_free_indices (new_tree)) ==
195+ length (get_free_indices (BinaryOperation {Mult} (tree, arg)))
196+
197+ return new_tree
198+ end
199+
200+ return nothing
201+ end
202+
203+ function flip_indices (ids, arg)
204+ for i ∈ ids
205+ arg = update_index (arg, i, flip (i); allow_shape_change = true )
206+ end
207+
208+ return arg
209+ end
210+
164211function simplify (:: Mult , arg1:: BinaryOperation{Mult} , arg2:: Tensor )
165212 op = BinaryOperation {Mult} (arg1, arg2)
166213 target_indices = unique (get_free_indices (op))
@@ -251,6 +298,43 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
251298 end
252299 end
253300
301+ arg1_free_indices = get_free_indices (arg1)
302+ arg2_free_indices = get_free_indices (arg2)
303+
304+ if length (arg1_free_indices) > 2 && length (target_indices) <= 2
305+ new_tree = sift_down (
306+ arg2,
307+ BinaryOperation {Mult} (
308+ flip_indices (get_free_indices (arg2' ), arg1. arg1),
309+ arg1. arg2,
310+ ),
311+ )
312+
313+ if isnothing (new_tree)
314+ new_tree = sift_down (
315+ arg2,
316+ BinaryOperation {Mult} (
317+ arg1. arg1,
318+ flip_indices (get_free_indices (arg2' ), arg1. arg2),
319+ ),
320+ )
321+ end
322+
323+ if ! isnothing (new_tree)
324+ @assert length (get_free_indices (new_tree)) <= 2
325+ return new_tree
326+ end
327+ end
328+
329+ if length (arg2_free_indices) > 2 && length (target_indices) <= 2
330+ new_tree = sift_down (arg1, arg2)
331+
332+ if ! isnothing (new_tree)
333+ @assert length (get_free_indices (new_tree)) <= 2
334+ return new_tree
335+ end
336+ end
337+
254338 return op
255339end
256340
0 commit comments