@@ -144,58 +144,55 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
144144
145145 @assert ! isnothing (d)
146146
147- target_indices = eliminate_indices (vcat (get_free_indices (arg1), get_indices (arg2)))
148147 factors = collect_factors (arg1)
149- vector_factors = filter (f -> f != d, factors)
148+ other_factors = filter (f -> f != d, factors)
149+ all_factors = vcat (filter (f -> f != d, factors), collect_factors (arg2))
150150 reshaped = []
151151
152- for f ∈ factors
153- if isequal (f, d)
154- continue
155- end
156-
157- free_ids = get_free_indices (f)
158-
159- if isempty (free_ids)
160- push! (reshaped, f)
161- elseif length (free_ids) == 1 || length (free_ids) == 2
162- @assert length (target_indices) == 1
163-
164- vector_index =
165- only (get_free_indices (to_binary_operation (Mult (), vector_factors)))
166-
167- current_idx = intersect (free_ids, [vector_index])
168-
169- if ! isempty (current_idx)
170- f = update_index (
171- f,
172- vector_index,
173- only (target_indices);
174- allow_shape_change = true ,
175- )
152+ el1 = eliminated_indices ([get_free_indices (d); get_free_indices (arg2)])
153+ ic1 = indices_in_common (d, to_binary_operation (Mult (), other_factors))
154+ el2 = eliminated_indices (
155+ [
156+ get_free_indices (d);
157+ get_free_indices (to_binary_operation (Mult (), other_factors))
158+ ],
159+ )
160+ ic2 = indices_in_common (d, arg2)
161+
162+ if ! isempty (intersect (el1, ic1)) || ! isempty (intersect (el2, ic2))
163+
164+ for f ∈ all_factors
165+ free_ids = get_free_indices (f)
166+
167+ if isempty (free_ids)
168+ push! (reshaped, f)
169+ else
170+ if can_contract (d, f)
171+ push! (reshaped, evaluate (Mult (), d, f))
172+ elseif ! isempty (indices_in_common (d, f))
173+ common_index = only (indices_in_common (d, f))
174+ target_index = if first (d. indices) == common_index
175+ last (d. indices)
176+ else
177+ first (d. indices)
178+ end
179+
180+ f = update_index (
181+ f,
182+ common_index,
183+ target_index;
184+ allow_shape_change = true ,
185+ )
186+
187+ push! (reshaped, f)
188+ else
189+ push! (reshaped, f)
190+ end
176191 end
177-
178- push! (reshaped, f)
179- else
180- @assert false " Not implemented, please open an issue with your input"
181192 end
182- end
183-
184- arg2_ids = get_free_indices (arg2)
185193
186- if length (arg2_ids) == 1
187- arg2 = update_index (
188- arg2,
189- only (arg2_ids),
190- only (target_indices);
191- allow_shape_change = true ,
192- )
193- push! (reshaped, arg2)
194- else
195- @assert false " Not implemented, please open an issue with your input"
194+ return to_binary_operation (Mult (), reshaped)
196195 end
197-
198- return to_binary_operation (Mult (), reshaped)
199196 end
200197
201198 return op
0 commit comments