Skip to content

Commit a9bb232

Browse files
committed
Collapse traces by default
1 parent c62699a commit a9bb232

5 files changed

Lines changed: 29 additions & 49 deletions

File tree

src/forward.jl

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,6 @@ function evaluate(::Mult, arg1::KrD, arg2::BinaryOperation{Mult})
253253
end
254254

255255
function evaluate(::Mult, arg1::BinaryOperation{Mult}, arg2::KrD)
256-
if is_trace(BinaryOperation{Mult}(arg1, arg2))
257-
return BinaryOperation{Mult}(arg1, arg2)
258-
end
259-
260256
ci = indices_in_common(arg1.arg1, arg1.arg2)
261257

262258
if !isempty(ci) && !is_trace(arg2)
@@ -339,10 +335,6 @@ function evaluate(::Mult, arg1::UnaryOperation, arg2::KrD)
339335
end
340336

341337
function evaluate(::Mult, arg1::KrD, arg2::UnaryOp) where {UnaryOp<:UnaryOperation}
342-
if is_trace(BinaryOperation{Mult}(arg1, arg2))
343-
return BinaryOperation{Mult}(arg1, arg2)
344-
end
345-
346338
if can_contract(evaluate(arg1), evaluate(arg2.arg))
347339
return UnaryOp(evaluate(Mult(), evaluate(arg1), evaluate(arg2.arg)))
348340
end
@@ -374,10 +366,6 @@ function _multiply_with_krd(arg1::Union{Monomial,KrD}, arg2::KrD)
374366
return BinaryOperation{Mult}(arg1, arg2)
375367
end
376368

377-
if is_trace(BinaryOperation{Mult}(arg1, arg2))
378-
return BinaryOperation{Mult}(arg1, arg2)
379-
end
380-
381369
if is_trace(arg1) || is_trace(arg2)
382370
return BinaryOperation{Mult}(arg1, arg2)
383371
end
@@ -410,10 +398,6 @@ function evaluate(
410398
arg1::BinaryOperation{Op},
411399
arg2::KrD,
412400
) where {Op<:AdditiveOperation}
413-
if is_trace(BinaryOperation{Mult}(arg1, arg2))
414-
return BinaryOperation{Mult}(arg1, arg2)
415-
end
416-
417401
return evaluate(
418402
Op(),
419403
evaluate(Mult(), evaluate(arg1.arg1), evaluate(arg2)),
@@ -426,10 +410,6 @@ function evaluate(
426410
arg1::KrD,
427411
arg2::BinaryOperation{Op},
428412
) where {Op<:AdditiveOperation}
429-
if is_trace(BinaryOperation{Mult}(arg1, arg2))
430-
return BinaryOperation{Mult}(arg1, arg2)
431-
end
432-
433413
return evaluate(
434414
Op(),
435415
evaluate(Mult(), arg1, evaluate(arg2.arg1)),

src/ricci.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ function tr(arg::Tensor)
232232
throw(de)
233233
end
234234

235-
return BinaryOperation{Mult}(arg, KrD(flip(free_ids[2]), flip(free_ids[1])))
235+
return evaluate(BinaryOperation{Mult}(arg, KrD(flip(free_ids[2]), flip(free_ids[1]))))
236236
end
237237

238238
function Base.sum(arg::Tensor)

src/std.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,9 @@ function _to_std_string(arg::Monomial)
214214
ids = get_indices(arg)
215215

216216
if length(ids) == 2
217-
if typeof(ids[1]) == Upper && typeof(ids[2]) == Lower
217+
if flip(ids[1]) == ids[2]
218+
return "tr(" * arg.id * ")"
219+
elseif typeof(ids[1]) == Upper && typeof(ids[2]) == Lower
218220
return arg.id
219221
elseif typeof(ids[1]) == Lower && typeof(ids[2]) == Upper
220222
return arg.id * ""
@@ -300,14 +302,16 @@ function get_contra_covariant_matrix(arg1::Tensor, arg2::Tensor)
300302
arg2_letters = [i.letter for i get_free_indices(arg2)]
301303
common_letter = intersect(arg1_letters, arg2_letters)
302304

303-
@assert length(common_letter) == 1
304-
305-
arg1_filt = filter(i->i.letter == first(common_letter), arg1_ids)
306-
arg2_filt = filter(i->i.letter == first(common_letter), arg2_ids)
305+
if length(common_letter) == 1
306+
arg1_filt = filter(i->i.letter == first(common_letter), arg1_ids)
307+
arg2_filt = filter(i->i.letter == first(common_letter), arg2_ids)
307308

308-
if typeof(first(arg1_filt)) == Lower && typeof(first(arg2_filt)) == Upper
309-
return (arg2, arg1)
310-
elseif typeof(first(arg1_filt)) == Upper && typeof(first(arg2_filt)) == Lower
309+
if typeof(first(arg1_filt)) == Lower && typeof(first(arg2_filt)) == Upper
310+
return (arg2, arg1)
311+
elseif typeof(first(arg1_filt)) == Upper && typeof(first(arg2_filt)) == Lower
312+
return (arg1, arg2)
313+
end
314+
elseif length(common_letter) == 2 # is a trace
311315
return (arg1, arg2)
312316
end
313317

@@ -377,6 +381,10 @@ function _to_std_string(arg::BinaryOperation{Mult})
377381
end
378382
end
379383

384+
if is_trace(arg)
385+
return "tr(" * _to_std_string(arg.arg1) * _to_std_string(arg.arg2) * ")"
386+
end
387+
380388
if isempty(target_indices) && (typeof(terms[1]) == KrD || typeof(terms[2]) == KrD)
381389
tensor = if typeof(first(terms)) == KrD
382390
last(terms)
@@ -386,9 +394,7 @@ function _to_std_string(arg::BinaryOperation{Mult})
386394

387395
tensor_free_ids = get_free_indices(tensor)
388396

389-
if length(tensor_free_ids) == 2
390-
return "tr(" * to_std_string(tensor) * ")"
391-
elseif length(tensor_free_ids) == 1
397+
if length(tensor_free_ids) == 1
392398
return "sum(" * _to_std_string(tensor) * ")"
393399
end
394400

test/ForwardTest.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -539,10 +539,7 @@ end
539539
A = Monomial("A", Upper(1), Lower(2))
540540
B = Monomial("B", Upper(2), Lower(3))
541541

542-
@test dc.evaluate(tr(A)) == BinaryOperation{dc.Mult}(
543-
Monomial("A", Upper(1), Lower(2)),
544-
KrD(Upper(2), Lower(1)),
545-
)
542+
@test dc.evaluate(tr(A)) == Monomial("A", Upper(2), Lower(2))
546543
@test equivalent(
547544
dc.evaluate(tr(A * B)),
548545
dc.BinaryOperation{dc.Mult}(A, Monomial("B", Upper(2), Lower(1))),

test/StdTest.jl

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ end
109109
return dc.BinaryOperation{dc.Add}(l, r)
110110
end
111111

112-
trA = mul(Monomial("A", Upper(1), Lower(2)), KrD(Upper(2), Lower(1)))
112+
trA = Monomial("A", Upper(2), Lower(2))
113+
113114
A = Monomial("A", Upper(3), Lower(4))
114115
B = Monomial("B", Upper(4), Lower(5))
115116
x = Monomial("x", Upper(4))
@@ -119,23 +120,19 @@ end
119120
@test to_std_string(mul(A, trA)) == "tr(A)A"
120121
@test to_std_string(mul(mul(trA, A), x)) == "tr(A)Ax"
121122

122-
trAB = mul(
123-
mul(Monomial("A", Upper(1), Lower(2)), Monomial("B", Upper(2), Lower(3))),
124-
KrD(Upper(3), Lower(1)),
125-
)
123+
trAB = mul(Monomial("A", Upper(1), Lower(2)), Monomial("B", Upper(2), Lower(1)))
124+
126125
@test to_std_string(trAB) == "tr(AB)"
127126
@test to_std_string(mul(trAB, A)) == "tr(AB)A"
128127
@test to_std_string(mul(trAB, B)) == "tr(AB)B"
129128
@test to_std_string(mul(mul(trAB, A), x)) == "tr(AB)Ax"
130129

131-
trApB = mul(
132-
add(Monomial("A", Upper(1), Lower(2)), Monomial("B", Upper(1), Lower(2))),
133-
KrD(Upper(2), Lower(1)),
134-
)
135-
@test to_std_string(trApB) == "tr(A + B)"
136-
@test to_std_string(mul(trApB, A)) == "tr(A + B)A"
137-
@test to_std_string(mul(trApB, B)) == "tr(A + B)B"
138-
@test to_std_string(mul(mul(trApB, A), x)) == "tr(A + B)Ax"
130+
trApB = add(Monomial("A", Upper(2), Lower(2)), Monomial("B", Upper(2), Lower(2)))
131+
132+
@test to_std_string(trApB) == "tr(A) + tr(B)"
133+
@test to_std_string(mul(trApB, A)) == "(tr(A) + tr(B))A"
134+
@test to_std_string(mul(trApB, B)) == "(tr(A) + tr(B))B"
135+
@test to_std_string(mul(mul(trApB, A), x)) == "(tr(A) + tr(B))Ax"
139136
end
140137

141138
@testset "to_std_string output is correct with all covariant bilinar form-vector contraction" begin

0 commit comments

Comments
 (0)