Skip to content

Commit 3521576

Browse files
committed
Fix evaluate in simplify
1 parent 8900828 commit 3521576

2 files changed

Lines changed: 12 additions & 10 deletions

File tree

src/simplify.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,19 @@ function simplify(arg::Value)
77
end
88

99
function simplify(arg::UnaryOperation{Op}) where {Op}
10-
return UnaryOperation{Op}(simplify(arg.arg))
10+
return UnaryOperation{Op}(simplify(evaluate(arg.arg)))
1111
end
1212

1313
function simplify(arg::BinaryOperation{Op}) where {Op}
14-
return evaluate(simplify(Op(), simplify(arg.arg1), simplify(arg.arg2)))
14+
return evaluate(
15+
simplify(Op(), simplify(evaluate(arg.arg1)), simplify(evaluate(arg.arg2))),
16+
)
1517
end
1618

1719
function simplify(::Mult, arg1::Variable, arg2::Variable)
18-
return BinaryOperation{Mult}(arg1, arg2)
20+
return evaluate(
21+
BinaryOperation{Mult}(simplify(evaluate(arg1)), simplify(evaluate(arg2))),
22+
)
1923
end
2024

2125
function elementwise_indices(arg1, arg2)
@@ -162,9 +166,11 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Variable)
162166
end
163167

164168
function simplify(::Mult, arg1::Value, arg2::Value)
165-
return evaluate(BinaryOperation{Mult}(arg1, arg2))
169+
return evaluate(
170+
BinaryOperation{Mult}(simplify(evaluate(arg1)), simplify(evaluate(arg2))),
171+
)
166172
end
167173

168174
function simplify(::Op, arg1::Value, arg2::Value) where {Op<:AdditiveOperation}
169-
return evaluate(BinaryOperation{Op}(arg1, arg2))
175+
return evaluate(BinaryOperation{Op}(simplify(evaluate(arg1)), simplify(evaluate(arg2))))
170176
end

test/ForwardTest.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -797,11 +797,7 @@ end
797797

798798
expected = dc.BinaryOperation{dc.Mult}(Variable("y", Lower(1)), Variable("x", Lower(1)))
799799

800-
# TODO: simplify should be sufficient here - remove evaluate
801-
@test equivalent(
802-
dc.simplify(dc.evaluate(dc.diff(e, Variable("z", Upper(9))))),
803-
expected,
804-
)
800+
@test equivalent(dc.simplify(dc.diff(e, Variable("z", Upper(9)))), expected)
805801

806802
expected = dc.BinaryOperation{dc.Mult}(Variable("y", Upper(1)), Variable("x", Upper(1)))
807803
@test equivalent(dc.simplify(dc.diff(e, Variable("z", Upper(9)))'), expected)

0 commit comments

Comments
 (0)