Skip to content

Commit ea7ebe0

Browse files
committed
Make simplification of tensor products more general
1 parent 9da0101 commit ea7ebe0

1 file changed

Lines changed: 42 additions & 45 deletions

File tree

src/simplify.jl

Lines changed: 42 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -144,58 +144,55 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Tensor)
144144

145145
@assert !isnothing(d)
146146

147-
target_indices = eliminate_indices(vcat(get_free_indices(arg1), get_indices(arg2)))
148147
factors = collect_factors(arg1)
149-
vector_factors = filter(f -> f != d, factors)
148+
other_factors = filter(f -> f != d, factors)
149+
all_factors = vcat(filter(f -> f != d, factors), collect_factors(arg2))
150150
reshaped = []
151151

152-
for f factors
153-
if isequal(f, d)
154-
continue
155-
end
156-
157-
free_ids = get_free_indices(f)
158-
159-
if isempty(free_ids)
160-
push!(reshaped, f)
161-
elseif length(free_ids) == 1 || length(free_ids) == 2
162-
@assert length(target_indices) == 1
163-
164-
vector_index =
165-
only(get_free_indices(to_binary_operation(Mult(), vector_factors)))
166-
167-
current_idx = intersect(free_ids, [vector_index])
168-
169-
if !isempty(current_idx)
170-
f = update_index(
171-
f,
172-
vector_index,
173-
only(target_indices);
174-
allow_shape_change = true,
175-
)
152+
el1 = eliminated_indices([get_free_indices(d); get_free_indices(arg2)])
153+
ic1 = indices_in_common(d, to_binary_operation(Mult(), other_factors))
154+
el2 = eliminated_indices(
155+
[
156+
get_free_indices(d);
157+
get_free_indices(to_binary_operation(Mult(), other_factors))
158+
],
159+
)
160+
ic2 = indices_in_common(d, arg2)
161+
162+
if !isempty(intersect(el1, ic1)) || !isempty(intersect(el2, ic2))
163+
164+
for f all_factors
165+
free_ids = get_free_indices(f)
166+
167+
if isempty(free_ids)
168+
push!(reshaped, f)
169+
else
170+
if can_contract(d, f)
171+
push!(reshaped, evaluate(Mult(), d, f))
172+
elseif !isempty(indices_in_common(d, f))
173+
common_index = only(indices_in_common(d, f))
174+
target_index = if first(d.indices) == common_index
175+
last(d.indices)
176+
else
177+
first(d.indices)
178+
end
179+
180+
f = update_index(
181+
f,
182+
common_index,
183+
target_index;
184+
allow_shape_change = true,
185+
)
186+
187+
push!(reshaped, f)
188+
else
189+
push!(reshaped, f)
190+
end
176191
end
177-
178-
push!(reshaped, f)
179-
else
180-
@assert false "Not implemented, please open an issue with your input"
181192
end
182-
end
183-
184-
arg2_ids = get_free_indices(arg2)
185193

186-
if length(arg2_ids) == 1
187-
arg2 = update_index(
188-
arg2,
189-
only(arg2_ids),
190-
only(target_indices);
191-
allow_shape_change = true,
192-
)
193-
push!(reshaped, arg2)
194-
else
195-
@assert false "Not implemented, please open an issue with your input"
194+
return to_binary_operation(Mult(), reshaped)
196195
end
197-
198-
return to_binary_operation(Mult(), reshaped)
199196
end
200197

201198
return op

0 commit comments

Comments
 (0)