Skip to content

Commit e8587b2

Browse files
committed
Add trivial simplification of powers
1 parent 212ca12 commit e8587b2

2 files changed

Lines changed: 33 additions & 1 deletion

File tree

src/forward.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,9 @@ function evaluate(::Mult, arg1::Power, arg2::Tensor)
328328
end
329329

330330
function evaluate(::Mult, arg1::Power, arg2::Power)
331-
# TODO: Simplify based on exponents
331+
if isequal(arg1.base, arg2.base)
332+
return Power(arg1.base, arg1.exponent + arg2.exponent)
333+
end
332334

333335
return BinaryOperation{Mult}(evaluate(arg1), evaluate(arg2))
334336
end

test/ForwardTest.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,36 @@ end
587587
@test evaluate(op2) == Zero(Upper(1), Lower(2))
588588
end
589589

590+
@testset "evaluate product of powers with same base" begin
591+
l = dc.Power(Variable("a", Upper(1)), 2)
592+
r1 = dc.Power(Variable("a", Upper(1)), 3)
593+
r2 = dc.Power(Variable("a", Upper(1)), -2)
594+
r3 = dc.Power(Variable("a", Upper(1)), -2)
595+
596+
for r (r1, r2, r3)
597+
op1 = dc.BinaryOperation{dc.Mult}(l, r)
598+
op2 = dc.BinaryOperation{dc.Mult}(r, l)
599+
600+
@test evaluate(op1) == dc.Power(Variable("a", Upper(1)), l.exponent + r.exponent)
601+
@test evaluate(op2) == dc.Power(Variable("a", Upper(1)), l.exponent + r.exponent)
602+
end
603+
end
604+
605+
@testset "evaluate product of powers with differing base" begin
606+
l = dc.Power(Variable("a", Upper(1)), 2)
607+
r1 = dc.Power(Variable("b", Upper(1)), 3)
608+
r2 = dc.Power(Variable("a", Upper(2)), -2)
609+
r3 = dc.Power(Variable("a", Lower(1)), -2)
610+
611+
for r (r1, r2, r3)
612+
op1 = dc.BinaryOperation{dc.Mult}(l, r)
613+
op2 = dc.BinaryOperation{dc.Mult}(r, l)
614+
615+
@test evaluate(op1) == op1
616+
@test evaluate(op2) == op2
617+
end
618+
end
619+
590620
@testset "evaluate product of Zero and power" begin
591621
p = dc.Power(Variable("a", Upper(1)), 2)
592622
z = Zero(Upper(1))

0 commit comments

Comments
 (0)