Skip to content

Commit c795167

Browse files
committed
Add Julia code output
1 parent da9fbb0 commit c795167

3 files changed

Lines changed: 102 additions & 5 deletions

File tree

src/DiffMatic.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@ include("ir.jl")
1010
include("simplify.jl")
1111
include("std.jl")
1212
include("forward.jl")
13+
include("julia.jl")
1314

1415
end

src/julia.jl

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
2+
function to_julia(arg::ir.Var)
3+
return Symbol(arg.id)
4+
end
5+
6+
function to_julia(arg::ir.Const)
7+
return :($(arg.value))
8+
end
9+
10+
function to_julia(arg::ir.Mat)
11+
if arg.id isa ir.Var
12+
return to_julia(arg.id)
13+
end
14+
15+
throw(RuntimeError("Unable to generate julia code for constant matrix of unknown size"))
16+
end
17+
18+
function to_julia(arg::ir.Vec)
19+
if arg.id isa ir.Var
20+
return to_julia(arg.id)
21+
end
22+
23+
throw(RuntimeError("Unable to generate julia code for constant vector of unknown size"))
24+
end
25+
26+
function to_julia(arg::ir.Scal)
27+
return to_julia(arg.id)
28+
end
29+
30+
function to_julia(arg::ir.Real)
31+
return arg
32+
end
33+
34+
function to_julia(arg::ir.Identity)
35+
return Symbol("I")
36+
end
37+
38+
function to_julia(arg::ir.Sin)
39+
return :(sin.($(to_julia(arg.arg))))
40+
end
41+
42+
function to_julia(arg::ir.Cos)
43+
return :(cos.($(to_julia(arg.arg))))
44+
end
45+
46+
function to_julia(arg::ir.Add)
47+
return :($(to_julia(arg.l)) + $(to_julia(arg.r)))
48+
end
49+
50+
function to_julia(arg::ir.Sub)
51+
return :($(to_julia(arg.l)) - $(to_julia(arg.r)))
52+
end
53+
54+
function to_julia(arg::ir.Product)
55+
return :($(to_julia(arg.l)) * $(to_julia(arg.r)))
56+
end
57+
58+
function to_julia(arg::ir.HadamardProduct)
59+
return :($(to_julia(arg.l)) .* $(to_julia(arg.r)))
60+
end
61+
62+
function to_julia(arg::ir.Power)
63+
return :($(to_julia(arg.base)) .^ $(to_julia(arg.exponent)))
64+
end
65+
66+
function to_julia(arg::ir.Trace)
67+
return :(tr($(to_julia(arg.arg))))
68+
end
69+
70+
function to_julia(arg::ir.Diag)
71+
return :(diagm($(to_julia(arg.arg))))
72+
end
73+
74+
function to_julia(arg::ir.Transpose)
75+
return :(transpose($(to_julia(arg.arg))))
76+
end
77+
78+
function to_julia(arg::ir.Sum)
79+
return :(sum($(to_julia(arg.arg))))
80+
end

src/std.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -306,9 +306,6 @@ function to_standard(arg::BinaryOperation{Mult})
306306
throw_not_std(arg)
307307
end
308308

309-
struct Ir end
310-
struct StdStr end
311-
312309
function to_std_str(arg::ir.Mat)
313310
if arg.id isa String
314311
return arg.id
@@ -453,6 +450,10 @@ function standardize(arg)
453450
return standardized
454451
end
455452

453+
struct Ir end
454+
struct StdStr end
455+
struct Julia end
456+
456457
"""
457458
to_std(expr)
458459
@@ -472,14 +473,29 @@ function to_std(arg; format = StdStr())
472473
return _to_std(format, arg)
473474
end
474475

476+
function _to_std(format::Ir, arg)
477+
standardized = standardize(arg)
478+
479+
return to_ir(standardized)
480+
end
481+
475482
function _to_std(format::StdStr, arg)
476483
standardized = standardize(arg)
477484

478485
return to_std_str(to_ir(standardized))
479486
end
480487

481-
function _to_std(format::Ir, arg)
488+
function _to_std(format::Julia, arg)
482489
standardized = standardize(arg)
483490

484-
return to_ir(standardized)
491+
ir = to_ir(standardized)
492+
op = to_julia(ir)
493+
494+
variables = DiffMatic.ir.get_variables(ir)
495+
496+
return quote
497+
function derivative($(Symbol.(variables)...))
498+
return $(op)
499+
end
500+
end
485501
end

0 commit comments

Comments
 (0)