Skip to content

Commit 2eab4e0

Browse files
committed
Add missing ir-to-Julia conversions and add tests
1 parent 56f25bf commit 2eab4e0

3 files changed

Lines changed: 37 additions & 5 deletions

File tree

src/ir.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@ function _get_variables(arg::Cos)
131131
return _get_variables(arg.arg)
132132
end
133133

134+
function _get_variables(arg::Log)
135+
return _get_variables(arg.arg)
136+
end
137+
134138
function _get_variables(arg::Add)
135139
return [_get_variables(arg.l); _get_variables(arg.r)]
136140
end

src/julia.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ function to_julia(arg::ir.Cos)
4343
return :(cos.($(to_julia(arg.arg))))
4444
end
4545

46+
function to_julia(arg::ir.Log)
47+
return :(log.($(to_julia(arg.arg))))
48+
end
49+
4650
function to_julia(arg::ir.Add)
4751
return :($(to_julia(arg.l)) + $(to_julia(arg.r)))
4852
end

test/JuliaTest.jl

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
using ForwardDiff
22

3-
using LinearAlgebra: diagm, I
3+
using LinearAlgebra: tr, diagm, I
44

55
@testset "test Julia function" begin
66
@matrix A B C
7-
@vector x y
7+
@vector x y z
88

99
= [
1010
0.055
@@ -18,6 +18,12 @@ using LinearAlgebra: diagm, I
1818
0.176
1919
]
2020

21+
= [
22+
0.578
23+
0.845
24+
0.711
25+
]
26+
2127
= [
2228
0.023 0.136 0.181
2329
0.443 0.132 0.576
@@ -36,6 +42,24 @@ using LinearAlgebra: diagm, I
3642
0.415 0.483 0.969
3743
]
3844

45+
@testset "function tr(x*x')" begin
46+
jfun = eval(to_std(tr(x*x'); format = dc.JuliaFunc()))
47+
48+
@test jfun(x̂) tr(x̂*')
49+
end
50+
51+
@testset "function tr(A*B'*C)" begin
52+
jfun = eval(to_std(tr(A*B'*C); format = dc.JuliaFunc()))
53+
54+
@test jfun(B̂, Ĉ, Â) tr(Â *' * Ĉ)
55+
end
56+
57+
@testset "function log(A' * x)" begin
58+
jfun = eval(to_std(log.(A' * x); format = dc.JuliaFunc()))
59+
60+
@test jfun(Â, x̂) log.(Â' * x̂)
61+
end
62+
3963
@testset "gradient of x'*x" begin
4064
jgrad = eval(to_std(gradient(x' * x, x); format = dc.JuliaFunc()))
4165

@@ -54,10 +78,10 @@ using LinearAlgebra: diagm, I
5478
@test jgrad(x̂) ForwardDiff.gradient(x -> cos(tr(x * x')), x̂)
5579
end
5680

57-
@testset "jacobian of sin(A * x + y)" begin
58-
jjac = eval(to_std(jacobian(sin.(A * x + y), x); format = dc.JuliaFunc()))
81+
@testset "jacobian of sin(A * x + y - z)" begin
82+
jjac = eval(to_std(jacobian(sin.(A * x + y - z), x); format = dc.JuliaFunc()))
5983

60-
@test jjac(Â, x̂, ŷ) ForwardDiff.jacobian(x -> sin.(Â * x + ŷ), x̂)
84+
@test jjac(Â, x̂, ŷ, ẑ) ForwardDiff.jacobian(x -> sin.(Â * x + -), x̂)
6185
end
6286

6387
@testset "jacobian of (A .* B) * C * x)' * x * x" begin

0 commit comments

Comments
 (0)