Skip to content

Commit 49bb65e

Browse files
committed
Attempt to rotate AST when node order > 2
1 parent b0835d2 commit 49bb65e

1 file changed

Lines changed: 84 additions & 0 deletions

File tree

src/simplify.jl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,53 @@ function can_apply(l, r)
161161
return false
162162
end
163163

164+
function sift_down(arg::Tensor, tree::BinaryOperation{Mult})
165+
arg_ids = get_free_indices(arg)
166+
167+
l_ids = get_free_indices(tree.arg1)
168+
r_ids = get_free_indices(tree.arg2)
169+
170+
# In this case, 'arg' and 'tree' have the same modes, then we can do element-wise multiplication
171+
if isempty(setdiff(arg_ids, l_ids)) && isempty(setdiff(l_ids, arg_ids))
172+
new_tree = BinaryOperation{Mult}(BinaryOperation{Mult}(arg, tree.arg1), tree.arg2)
173+
@assert length(get_free_indices(new_tree)) ==
174+
length(get_free_indices(BinaryOperation{Mult}(tree, arg)))
175+
176+
return new_tree
177+
elseif isempty(setdiff(arg_ids, r_ids)) && isempty(setdiff(r_ids, arg_ids))
178+
new_tree = BinaryOperation{Mult}(tree.arg1, BinaryOperation{Mult}(tree.arg2, arg))
179+
@assert length(get_free_indices(new_tree)) ==
180+
length(get_free_indices(BinaryOperation{Mult}(tree, arg)))
181+
182+
return new_tree
183+
end
184+
185+
# Otherwise, recurse
186+
if isempty(setdiff(arg_ids, l_ids))
187+
new_tree = BinaryOperation{Mult}(sift_down(arg, tree.arg1), tree.arg2)
188+
@assert length(get_free_indices(new_tree)) ==
189+
length(get_free_indices(BinaryOperation{Mult}(tree, arg)))
190+
191+
return new_tree
192+
elseif isempty(setdiff(arg_ids, r_ids))
193+
new_tree = BinaryOperation{Mult}(tree.arg1, sift_down(arg, tree.arg2))
194+
@assert length(get_free_indices(new_tree)) ==
195+
length(get_free_indices(BinaryOperation{Mult}(tree, arg)))
196+
197+
return new_tree
198+
end
199+
200+
return nothing
201+
end
202+
203+
function flip_indices(ids, arg)
204+
for i ids
205+
arg = update_index(arg, i, flip(i); allow_shape_change = true)
206+
end
207+
208+
return arg
209+
end
210+
164211
function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
165212
op = BinaryOperation{Mult}(arg1, arg2)
166213
target_indices = unique(get_free_indices(op))
@@ -251,6 +298,43 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
251298
end
252299
end
253300

301+
arg1_free_indices = get_free_indices(arg1)
302+
arg2_free_indices = get_free_indices(arg2)
303+
304+
if length(arg1_free_indices) > 2 && length(target_indices) <= 2
305+
new_tree = sift_down(
306+
arg2,
307+
BinaryOperation{Mult}(
308+
flip_indices(get_free_indices(arg2'), arg1.arg1),
309+
arg1.arg2,
310+
),
311+
)
312+
313+
if isnothing(new_tree)
314+
new_tree = sift_down(
315+
arg2,
316+
BinaryOperation{Mult}(
317+
arg1.arg1,
318+
flip_indices(get_free_indices(arg2'), arg1.arg2),
319+
),
320+
)
321+
end
322+
323+
if !isnothing(new_tree)
324+
@assert length(get_free_indices(new_tree)) <= 2
325+
return new_tree
326+
end
327+
end
328+
329+
if length(arg2_free_indices) > 2 && length(target_indices) <= 2
330+
new_tree = sift_down(arg1, arg2)
331+
332+
if !isnothing(new_tree)
333+
@assert length(get_free_indices(new_tree)) <= 2
334+
return new_tree
335+
end
336+
end
337+
254338
return op
255339
end
256340

0 commit comments

Comments
 (0)