Skip to content

Commit 6b99bfd

Browse files
committed
Make 'to_standard' more lenient in terms of output shape
1 parent 7e889fa commit 6b99bfd

2 files changed

Lines changed: 22 additions & 19 deletions

File tree

src/forward.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,15 @@ function evaluate(::Mult, arg1::BinaryOperation{Mult}, arg2::BinaryOperation{Mul
187187
new_args = []
188188

189189
available1 = Any[arg1.arg1; arg1.arg2]
190+
if is_elementwise_multiplication(arg1.arg1, arg1.arg2)
191+
available1 = Any[arg1]
192+
end
193+
190194
available2 = Any[arg2.arg1; arg2.arg2]
195+
if is_elementwise_multiplication(arg2.arg1, arg2.arg2)
196+
available2 = Any[arg2]
197+
end
198+
191199

192200
for i eachindex(available1)
193201
if isnothing(available1[i])
@@ -225,7 +233,7 @@ function evaluate(::Mult, arg1::BinaryOperation{Mult}, arg2::BinaryOperation{Mul
225233
if isnothing(new_arg)
226234
return args[1]
227235
else
228-
return evaluate(BinaryOperation{Mult}(new_arg, args[1]))
236+
return BinaryOperation{Mult}(new_arg, args[1])
229237
end
230238
end
231239

src/std.jl

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,7 @@ end
629629
function to_standard(arg::BinaryOperation{Mult})
630630
target_indices = unique(get_free_indices(arg))
631631
target_len = length(target_indices)
632+
is_scalar = isempty(target_indices)
632633

633634
if length(target_indices) > 2
634635
throw_not_std(arg)
@@ -637,24 +638,18 @@ function to_standard(arg::BinaryOperation{Mult})
637638
l = to_standard(arg.arg1)
638639
r = to_standard(arg.arg2)
639640

640-
attempt = BinaryOperation{Mult}(l, r)
641-
if get_free_indices(attempt) == target_indices
642-
return attempt
643-
end
644-
645-
attempt = BinaryOperation{Mult}(adjoint(l), r)
646-
if get_free_indices(attempt) == target_indices
647-
return attempt
648-
end
649-
650-
attempt = BinaryOperation{Mult}(l, adjoint(r))
651-
if get_free_indices(attempt) == target_indices
652-
return attempt
653-
end
654-
655-
attempt = BinaryOperation{Mult}(adjoint(l), adjoint(r))
656-
if get_free_indices(attempt) == target_indices
657-
return attempt
641+
attempts = (
642+
BinaryOperation{Mult}(l, r),
643+
BinaryOperation{Mult}(adjoint(l), r),
644+
BinaryOperation{Mult}(l, adjoint(r)),
645+
BinaryOperation{Mult}(adjoint(l), adjoint(r)),
646+
)
647+
648+
for attempt attempts
649+
if length(get_free_indices(attempt)) == target_len &&
650+
(is_scalar || last(get_free_indices(attempt)) == last(target_indices))
651+
return attempt
652+
end
658653
end
659654

660655
throw_not_std(arg)

0 commit comments

Comments
 (0)