@@ -132,9 +132,12 @@ function get_next_letter(exprs...)
132132 indices = get_indices .(exprs)
133133
134134 letters = [index. letter for index ∈ Iterators. flatten (indices)]
135- max_letter = maximum (letters)
136135
137- return max_letter + 1
136+ if isempty (letters)
137+ return 1
138+ end
139+
140+ return maximum (letters) + 1
138141end
139142
140143function is_permutation (l:: AbstractArray{T} , r:: AbstractArray{T} ) where {T}
@@ -289,23 +292,78 @@ function Base.broadcasted(::typeof(^), arg1::Tensor, arg2::Int)
289292 return BinaryOperation {Pow} (arg1, arg2)
290293end
291294
295+ function replace_letters (arg:: BinaryOperation{Mult} , letter_map:: Dict )
296+ return BinaryOperation {Mult} (
297+ replace_letters (arg. arg1, letter_map),
298+ replace_letters (arg. arg2, letter_map),
299+ )
300+ end
301+
302+ function replace_letters (arg:: BinaryOperation{Pow} , letter_map:: Dict )
303+ return BinaryOperation {Mult} (
304+ replace_letters (arg. arg1, letter_map),
305+ replace_letters (arg. arg2, letter_map),
306+ )
307+ end
308+
309+ function replace_letters (
310+ arg:: BinaryOperation{Op} ,
311+ letter_map:: Dict ,
312+ ) where {Op<: AdditiveOperation }
313+ return BinaryOperation {Op} (
314+ replace_letters (arg. arg1, letter_map),
315+ replace_letters (arg. arg2, letter_map),
316+ )
317+ end
318+
319+ function replace_letters (arg:: Union{Monomial,Zero,KrD} , letter_map:: Dict )
320+ new_indices = LowerOrUpperIndex[]
321+
322+ for i ∈ arg. indices
323+ if haskey (letter_map, i. letter)
324+ push! (new_indices, same_to (i, letter_map[i. letter]))
325+ else
326+ push! (new_indices, i)
327+ end
328+ end
329+
330+ newarg = deepcopy (arg)
331+ empty! (newarg. indices)
332+
333+ for ni ∈ new_indices
334+ push! (newarg. indices, ni)
335+ end
336+
337+ return newarg
338+ end
339+
340+ function replace_letters (arg:: UnaryOperation{Op} , letter_map:: Dict ) where {Op}
341+ return UnaryOperation {Op} (replace_letters (arg. arg, letter_map))
342+ end
343+
344+ function replace_letters (arg:: Real , letter_map:: Dict )
345+ return arg
346+ end
347+
292348function Base.:(* )(arg1:: Tensor , arg2:: Real )
293349 return arg2 * arg1
294350end
295351
296352function Base.:(* )(arg1:: Value , arg2:: Tensor )
297353 arg1_indices, arg2_indices = unique .(get_indices .((arg1, arg2)))
298- intersecting_letters = intersect (get_letters (arg1_indices), get_letters (arg2_indices))
354+ intersecting_letters =
355+ unique (intersect (get_letters (arg1_indices), get_letters (arg2_indices)))
299356
300- for letter ∈ intersecting_letters
301- for index ∈ arg2_indices
302- if index. letter == letter
303- new_letter = get_next_letter (arg1, arg2)
304- arg2 = update_index (arg2, index, same_to (index, new_letter))
305- end
306- end
357+ new_letters = Dict ()
358+ next_letter = get_next_letter (arg1, arg2)
359+
360+ for l ∈ intersecting_letters
361+ new_letters[l] = next_letter
362+ next_letter += 1
307363 end
308364
365+ arg2 = replace_letters (arg2, new_letters)
366+
309367 arg1_free_indices = get_free_indices (arg1)
310368 arg2_free_indices = get_free_indices (arg2)
311369
@@ -321,26 +379,6 @@ function Base.:(*)(arg1::Value, arg2::Tensor)
321379 throw (DomainError (arg2, " Multiplication involving tensor \" $arg2 \" is ambiguous" ))
322380 end
323381
324- intersecting_letters =
325- intersect (get_letters (arg1_free_indices), get_letters (arg2_free_indices))
326-
327- if ! isempty (intersecting_letters)
328- if length (intersecting_letters) > 1 ||
329- arg1_free_indices[end ]. letter != arg2_free_indices[1 ]. letter
330- # Intersecting letters need updating if:
331- # - There are multiple intersecting letters
332- # - The intersecting letters are any other than the contracting indices
333-
334- for i ∈ arg1_free_indices
335- if i. letter ∈ intersecting_letters
336- arg1 = update_index (arg1, i, same_to (i, get_next_letter (arg1, arg2)))
337- end
338- end
339- end
340-
341- arg1_free_indices = get_free_indices (arg1)
342- end
343-
344382 # TODO : WETWET, simplify, add e.g. get_lower(arg::IndexList) and get_upper(arg::IndexLists)
345383 if typeof (arg1_free_indices[end ]) == Lower && typeof (arg2_free_indices[1 ]) == Upper
346384 new_letter = get_next_letter (arg1, arg2)
0 commit comments