Skip to content

Commit c78dae4

Browse files
committed
Replace all overlapping letters in Base.(*)
1 parent a9bb232 commit c78dae4

1 file changed

Lines changed: 68 additions & 30 deletions

File tree

src/ricci.jl

Lines changed: 68 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,12 @@ function get_next_letter(exprs...)
132132
indices = get_indices.(exprs)
133133

134134
letters = [index.letter for index Iterators.flatten(indices)]
135-
max_letter = maximum(letters)
136135

137-
return max_letter + 1
136+
if isempty(letters)
137+
return 1
138+
end
139+
140+
return maximum(letters) + 1
138141
end
139142

140143
function is_permutation(l::AbstractArray{T}, r::AbstractArray{T}) where {T}
@@ -289,23 +292,78 @@ function Base.broadcasted(::typeof(^), arg1::Tensor, arg2::Int)
289292
return BinaryOperation{Pow}(arg1, arg2)
290293
end
291294

295+
function replace_letters(arg::BinaryOperation{Mult}, letter_map::Dict)
296+
return BinaryOperation{Mult}(
297+
replace_letters(arg.arg1, letter_map),
298+
replace_letters(arg.arg2, letter_map),
299+
)
300+
end
301+
302+
function replace_letters(arg::BinaryOperation{Pow}, letter_map::Dict)
303+
return BinaryOperation{Mult}(
304+
replace_letters(arg.arg1, letter_map),
305+
replace_letters(arg.arg2, letter_map),
306+
)
307+
end
308+
309+
function replace_letters(
310+
arg::BinaryOperation{Op},
311+
letter_map::Dict,
312+
) where {Op<:AdditiveOperation}
313+
return BinaryOperation{Op}(
314+
replace_letters(arg.arg1, letter_map),
315+
replace_letters(arg.arg2, letter_map),
316+
)
317+
end
318+
319+
function replace_letters(arg::Union{Monomial,Zero,KrD}, letter_map::Dict)
320+
new_indices = LowerOrUpperIndex[]
321+
322+
for i arg.indices
323+
if haskey(letter_map, i.letter)
324+
push!(new_indices, same_to(i, letter_map[i.letter]))
325+
else
326+
push!(new_indices, i)
327+
end
328+
end
329+
330+
newarg = deepcopy(arg)
331+
empty!(newarg.indices)
332+
333+
for ni new_indices
334+
push!(newarg.indices, ni)
335+
end
336+
337+
return newarg
338+
end
339+
340+
function replace_letters(arg::UnaryOperation{Op}, letter_map::Dict) where {Op}
341+
return UnaryOperation{Op}(replace_letters(arg.arg, letter_map))
342+
end
343+
344+
function replace_letters(arg::Real, letter_map::Dict)
345+
return arg
346+
end
347+
292348
function Base.:(*)(arg1::Tensor, arg2::Real)
293349
return arg2 * arg1
294350
end
295351

296352
function Base.:(*)(arg1::Value, arg2::Tensor)
297353
arg1_indices, arg2_indices = unique.(get_indices.((arg1, arg2)))
298-
intersecting_letters = intersect(get_letters(arg1_indices), get_letters(arg2_indices))
354+
intersecting_letters =
355+
unique(intersect(get_letters(arg1_indices), get_letters(arg2_indices)))
299356

300-
for letter intersecting_letters
301-
for index arg2_indices
302-
if index.letter == letter
303-
new_letter = get_next_letter(arg1, arg2)
304-
arg2 = update_index(arg2, index, same_to(index, new_letter))
305-
end
306-
end
357+
new_letters = Dict()
358+
next_letter = get_next_letter(arg1, arg2)
359+
360+
for l intersecting_letters
361+
new_letters[l] = next_letter
362+
next_letter += 1
307363
end
308364

365+
arg2 = replace_letters(arg2, new_letters)
366+
309367
arg1_free_indices = get_free_indices(arg1)
310368
arg2_free_indices = get_free_indices(arg2)
311369

@@ -321,26 +379,6 @@ function Base.:(*)(arg1::Value, arg2::Tensor)
321379
throw(DomainError(arg2, "Multiplication involving tensor \"$arg2\" is ambiguous"))
322380
end
323381

324-
intersecting_letters =
325-
intersect(get_letters(arg1_free_indices), get_letters(arg2_free_indices))
326-
327-
if !isempty(intersecting_letters)
328-
if length(intersecting_letters) > 1 ||
329-
arg1_free_indices[end].letter != arg2_free_indices[1].letter
330-
# Intersecting letters need updating if:
331-
# - There are multiple intersecting letters
332-
# - The intersecting letters are any other than the contracting indices
333-
334-
for i arg1_free_indices
335-
if i.letter intersecting_letters
336-
arg1 = update_index(arg1, i, same_to(i, get_next_letter(arg1, arg2)))
337-
end
338-
end
339-
end
340-
341-
arg1_free_indices = get_free_indices(arg1)
342-
end
343-
344382
# TODO: WETWET, simplify, add e.g. get_lower(arg::IndexList) and get_upper(arg::IndexLists)
345383
if typeof(arg1_free_indices[end]) == Lower && typeof(arg2_free_indices[1]) == Upper
346384
new_letter = get_next_letter(arg1, arg2)

0 commit comments

Comments
 (0)