@@ -71,7 +71,7 @@ function to_binary_operation(op::Op, terms::AbstractArray) where {Op}
7171 return BinaryOperation {Op} (to_binary_operation (op, terms[1 : (end - 1 )]), terms[end ])
7272end
7373
74- function simplify (:: Mult , arg1:: BinaryOperation{Mult} , arg2:: KrD )
74+ function simplify (:: Mult , arg1:: BinaryOperation{Mult} , arg2:: Literal )
7575 if is_diag (arg1)
7676 d = get_diag_delta (arg1)
7777
@@ -103,21 +103,16 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::KrD)
103103 return to_binary_operation (Mult (), reshaped)
104104 end
105105
106- if is_trace (arg2)
107- s = first (arg2. indices)
108-
106+ if can_contract (arg1, arg2) && length (get_free_indices (arg2)) == 1
109107 elwise_ids = elementwise_indices (arg1. arg1, arg1. arg2)
110- last_index = get_last_letter (union (get_free_indices (arg1), get_free_indices (arg2)))
111-
112- if ! isempty (elwise_ids)
113- if s ∈ elwise_ids || flip (s) ∈ elwise_ids
114- if last_index ∈ get_free_indices (arg1. arg1)
115- return evaluate (BinaryOperation {Mult} (arg1. arg1, adjoint (arg1. arg2)))
116- elseif last_index ∈ get_free_indices (arg1. arg2)
117- return evaluate (BinaryOperation {Mult} (adjoint (arg1. arg1), arg1. arg2))
118- end
119-
120- @assert false " Unreachable"
108+ remaining_index =
109+ eliminate_indices (union (get_free_indices (arg1), get_free_indices (arg2)))
110+
111+ if length (elwise_ids) == 1 && length (remaining_index) == 1
112+ if only (remaining_index) ∈ get_free_indices (arg1. arg1)
113+ return evaluate (BinaryOperation {Mult} (arg1. arg1, adjoint (arg1. arg2)))
114+ elseif only (remaining_index) ∈ get_free_indices (arg1. arg2)
115+ return evaluate (BinaryOperation {Mult} (adjoint (arg1. arg1), arg1. arg2))
121116 end
122117 end
123118 end
0 commit comments