Skip to content

Commit a521515

Browse files
committed
Add simplification for trivial quotients
1 parent 44bd035 commit a521515

2 files changed

Lines changed: 25 additions & 0 deletions

File tree

src/forward.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,14 @@ function evaluate(::Mult, arg1::Real, arg2::Zero)
545545
return evaluate(arg2)
546546
end
547547

548+
function evaluate(::Div, arg1::Value, arg2::Value)
549+
if arg1 == arg2
550+
return Literal(1, get_free_indices(arg1)...)
551+
end
552+
553+
return BinaryOperation{Div}(arg1, arg2)
554+
end
555+
548556
function evaluate(::Add, arg1::Zero, arg2::Zero)
549557
@assert is_permutation(arg1, arg2)
550558

test/ForwardTest.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,23 @@ end
8181
dc.BinaryOperation{dc.Div}(Variable("x", Upper(2)), l)
8282
end
8383

84+
@testset "evaluate trivially simplifiable div" begin
85+
x = Variable("x", Upper(1))
86+
d = KrD(Upper(2), Lower(1))
87+
88+
function prod(l, r)
89+
return dc.BinaryOperation{dc.Mult}(l, r)
90+
end
91+
92+
function div(n, d)
93+
return dc.BinaryOperation{dc.Div}(n, d)
94+
end
95+
96+
@test dc.evaluate(div(prod(x, d), prod(x, d))) == Literal(1, Upper(2))
97+
@test dc.evaluate(div(prod(x, d), prod(d, x))) == Literal(1, Upper(2))
98+
@test dc.evaluate(div(prod(d, x), prod(x, d))) == Literal(1, Upper(2))
99+
end
100+
84101
@testset "evaluate BinaryOperation{AdditiveOperation} Matrix and KrD" begin
85102
X = Variable("X", Upper(2), Lower(3))
86103
d = KrD(Upper(2), Lower(3))

0 commit comments

Comments
 (0)