Skip to content

Commit 0d52136

Browse files
committed
Consider only free indices for element-wise products
1 parent a86ee74 commit 0d52136

2 files changed

Lines changed: 5 additions & 3 deletions

File tree

src/forward.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ function evaluate(::Mult, arg1::T, arg2::BinaryOperation{Mult}) where {T<:Real}
115115
end
116116

117117
function indices_in_common(arg1, arg2)
118-
arg1_indices = get_indices(arg1)
119-
arg2_indices = get_indices(arg2)
118+
arg1_indices = get_free_indices(arg1)
119+
arg2_indices = get_free_indices(arg2)
120120

121121
return intersect(arg1_indices, arg2_indices)
122122
end

test/StdStrTest.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,12 @@ end
8181
end
8282

8383
@testset "test Hessian in standard notation" begin
84-
@matrix A
84+
@matrix A B
8585
@vector x
8686

8787
@test to_std(hessian(x' * A * x, x)) == "Aᵀ + A"
8888
@test to_std(hessian(2 * x' * A * x, x)) == "2Aᵀ + 2A"
8989
@test to_std(hessian(2 * x' * x, x)) == "4I"
90+
@test to_std(hessian(sin(cos(x' * A * B' * x)), x)) ==
91+
"cos(cos(xᵀBAᵀx))((-1)sin(xᵀBAᵀx)(BAᵀ + ABᵀ) + (BAᵀx + ABᵀx)(-1)cos(xᵀBAᵀx)(xᵀABᵀ + xᵀBAᵀ)) + (-1)(-1)sin(xᵀBAᵀx)(BAᵀx + ABᵀx)sin(cos(xᵀBAᵀx))(-1)sin(xᵀBAᵀx)(xᵀABᵀ + xᵀBAᵀ)"
9092
end

0 commit comments

Comments
 (0)