@@ -135,33 +135,46 @@ function simplify(::Mult, arg1::Tensor, arg2::BinaryOperation{Mult})
135135end
136136
137137function simplify (:: Mult , arg1:: BinaryOperation{Mult} , arg2:: Tensor )
138- if is_diag (arg1) && ! is_elementwise_multiplication (arg1, arg2)
138+ op = BinaryOperation {Mult} (arg1, arg2)
139+
140+ if is_diag (arg1) &&
141+ ! is_elementwise_multiplication (arg1, arg2) &&
142+ length (get_free_indices (op)) == 1
139143 d = get_diag_delta (arg1)
140144
141145 @assert ! isnothing (d)
142146
143147 target_indices = eliminate_indices (vcat (get_free_indices (arg1), get_indices (arg2)))
144148 factors = collect_factors (arg1)
149+ vector_factors = filter (f -> f != d, factors)
145150 reshaped = []
146151
147152 for f ∈ factors
148- if f isa KrD
153+ if isequal (f, d)
149154 continue
150155 end
151156
152157 free_ids = get_free_indices (f)
158+
153159 if isempty (free_ids)
154160 push! (reshaped, f)
155- elseif length (free_ids) == 1
161+ elseif length (free_ids) == 1 || length (free_ids) == 2
156162 @assert length (target_indices) == 1
157163
158- current_idx = intersect (free_ids, get_free_indices (d))
159- f = update_index (
160- f,
161- only (current_idx),
162- only (target_indices);
163- allow_shape_change = true ,
164- )
164+ vector_index =
165+ only (get_free_indices (to_binary_operation (Mult (), vector_factors)))
166+
167+ current_idx = intersect (free_ids, [vector_index])
168+
169+ if ! isempty (current_idx)
170+ f = update_index (
171+ f,
172+ vector_index,
173+ only (target_indices);
174+ allow_shape_change = true ,
175+ )
176+ end
177+
165178 push! (reshaped, f)
166179 else
167180 @assert false " Not implemented, please open an issue with your input"
@@ -185,7 +198,7 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
185198 return to_binary_operation (Mult (), reshaped)
186199 end
187200
188- return BinaryOperation {Mult} (arg1, arg2)
201+ return op
189202end
190203
191204function simplify (:: Mult , arg1:: BinaryOperation{Mult} , arg2:: BinaryOperation{Mult} )
0 commit comments