629629function to_standard (arg:: BinaryOperation{Mult} )
630630 target_indices = unique (get_free_indices (arg))
631631 target_len = length (target_indices)
632+ is_scalar = isempty (target_indices)
632633
633634 if length (target_indices) > 2
634635 throw_not_std (arg)
@@ -637,24 +638,18 @@ function to_standard(arg::BinaryOperation{Mult})
637638 l = to_standard (arg. arg1)
638639 r = to_standard (arg. arg2)
639640
640- attempt = BinaryOperation {Mult} (l, r)
641- if get_free_indices (attempt) == target_indices
642- return attempt
643- end
644-
645- attempt = BinaryOperation {Mult} (adjoint (l), r)
646- if get_free_indices (attempt) == target_indices
647- return attempt
648- end
649-
650- attempt = BinaryOperation {Mult} (l, adjoint (r))
651- if get_free_indices (attempt) == target_indices
652- return attempt
653- end
654-
655- attempt = BinaryOperation {Mult} (adjoint (l), adjoint (r))
656- if get_free_indices (attempt) == target_indices
657- return attempt
641+ attempts = (
642+ BinaryOperation {Mult} (l, r),
643+ BinaryOperation {Mult} (adjoint (l), r),
644+ BinaryOperation {Mult} (l, adjoint (r)),
645+ BinaryOperation {Mult} (adjoint (l), adjoint (r)),
646+ )
647+
648+ for attempt ∈ attempts
649+ if length (get_free_indices (attempt)) == target_len &&
650+ (is_scalar || last (get_free_indices (attempt)) == last (target_indices))
651+ return attempt
652+ end
658653 end
659654
660655 throw_not_std (arg)
0 commit comments