Skip to content

Commit 4afd7ab

Browse files
committed
Move all Real:s during 'evaluate'
1 parent 12612d6 commit 4afd7ab

2 files changed

Lines changed: 21 additions & 3 deletions

File tree

src/forward.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ function evaluate(::Mult, arg1::BinaryOperation{Mult}, arg2::Real)
7878
end
7979

8080
function evaluate(::Mult, arg1::Real, arg2::BinaryOperation{Mult})
81+
if arg1 == 1
82+
return arg2
83+
end
84+
8185
if arg2.arg1 isa Real
8286
return BinaryOperation{Mult}(arg1 * arg2.arg1, arg2.arg2)
8387
elseif arg2.arg2 isa Real
@@ -140,6 +144,10 @@ function evaluate(::Mult, arg1::Monomial, arg2::BinaryOperation{Mult})
140144
end
141145

142146
function evaluate(::Mult, arg1::BinaryOperation{Mult}, arg2::Monomial)
147+
if arg1.arg1 isa Real
148+
return BinaryOperation{Mult}(arg1.arg1, BinaryOperation{Mult}(arg1.arg2, arg2))
149+
end
150+
143151
is_elementwise = is_elementwise_multiplication(arg1.arg1, arg1.arg2)
144152
arg1_indices, arg2_indices = get_free_indices.((arg1, arg2))
145153

@@ -182,6 +190,16 @@ end
182190
function evaluate(::Mult, arg1::BinaryOperation{Mult}, arg2::BinaryOperation{Mult})
183191
if arg1.arg1 == -1 && arg2.arg1 == -1
184192
return BinaryOperation{Mult}(arg1.arg2, arg2.arg2)
193+
elseif arg1.arg1 == -1
194+
return BinaryOperation{Mult}(
195+
arg1.arg1,
196+
evaluate(BinaryOperation{Mult}(arg1.arg2, arg2)),
197+
)
198+
elseif arg2.arg1 == -1
199+
return BinaryOperation{Mult}(
200+
arg2.arg1,
201+
evaluate(BinaryOperation{Mult}(arg2.arg2, arg1)),
202+
)
185203
end
186204

187205
new_args = []

test/StdTest.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,8 +347,8 @@ end
347347
@test to_std_string(gradient(sum(2 * x), x)) == "2vec(1)"
348348
@test to_std_string(gradient(2 * sum(sin(x)), x)) == "2cos(x)"
349349
@test to_std_string(gradient(sum(2 * sin(x)), x)) == "2cos(x)"
350-
@test to_std_string(gradient(2 * sum(cos(A * x + y)), x)) == "2(-1)Aᵀsin(Ax + y)"
351-
@test to_std_string(gradient(sum(2 * cos(A * x + y)), x)) == "2(-1)Aᵀsin(Ax + y)"
350+
@test to_std_string(gradient(2 * sum(cos(A * x + y)), x)) == "(-2)Aᵀsin(Ax + y)"
351+
@test to_std_string(gradient(sum(2 * cos(A * x + y)), x)) == "(-2)Aᵀsin(Ax + y)"
352352
@test to_std_string(gradient(sum(x .^ 2), x)) == "2x"
353353
@test to_std_string(gradient(sum(x .^ 3), x)) == "3x²"
354354
@test to_std_string(gradient(sum((x + y) .^ 2), x)) == "2(x + y)"
@@ -374,7 +374,7 @@ end
374374
@matrix A B C X
375375
@vector x y z
376376

377-
@test to_std_string(derivative(sum(-y .* (X*z)), X)) == "z(-1)yᵀ"
377+
@test to_std_string(derivative(sum(-y .* (X*z)), X)) == "(-1)zyᵀ"
378378
@test to_std_string(derivative(sum((A .* B) * C * x), x)) == "vec(1)ᵀ(A ⊙ B)C"
379379
end
380380

0 commit comments

Comments
 (0)