Skip to content

Commit 0bd8849

Browse files
committed
Add intermediate representation
1 parent 6201163 commit 0bd8849

5 files changed

Lines changed: 425 additions & 247 deletions

File tree

docs/src/index.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ H = hessian(expr, x)
4040
# output
4141
A₇⁶ + A⁶₇
4242
```
43-
Convert the gradient into standard notation using `to_std_string`:
43+
Convert the gradient into standard notation using `to_std`:
4444
```jldoctest intro
45-
to_std_string(g)
45+
to_std(g)
4646
4747
# output
4848
@@ -51,7 +51,7 @@ to_std_string(g)
5151

5252
Convert the the Hessian into standard notation:
5353
```jldoctest intro
54-
to_std_string(H)
54+
to_std(H)
5555
5656
# output
5757
@@ -61,7 +61,7 @@ to_std_string(H)
6161
Jacobians can be computed with `jacobian`:
6262

6363
```jldoctest intro
64-
to_std_string(jacobian(A * x, x))
64+
to_std(jacobian(A * x, x))
6565
6666
# output
6767
@@ -71,16 +71,16 @@ to_std_string(jacobian(A * x, x))
7171
The method `derivative` can be used to compute arbitrary derivatives.
7272

7373
```jldoctest intro
74-
to_std_string(derivative(tr(A), A))
74+
to_std(derivative(tr(A), A))
7575
7676
# output
7777
7878
"I"
7979
```
80-
The method `to_std_string` will throw an exception when given an expression that that cannot be converted to
80+
The method `to_std` will throw an exception when given an expression that that cannot be converted to
8181
standard notation:
8282
```jldoctest intro
83-
to_std_string(derivative(A, A))
83+
to_std(derivative(A, A))
8484
8585
# output
8686

src/DiffMatic.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
module DiffMatic
66

77
include("index.jl")
8+
include("ir.jl")
89
include("ricci.jl")
910
include("simplify.jl")
1011
include("std.jl")

src/ir.jl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
module ir
2+
3+
abstract type IR end
4+
5+
struct Mat <: IR
6+
id::Union{String,Real}
7+
end
8+
9+
struct Vec <: IR
10+
id::Union{String,Real}
11+
end
12+
13+
struct Scal <: IR
14+
id::Union{String,Real}
15+
end
16+
17+
struct Identity <: IR end
18+
19+
struct Sin <: IR
20+
arg::IR
21+
end
22+
23+
struct Cos <: IR
24+
arg::IR
25+
end
26+
27+
struct Add <: IR
28+
l::IR
29+
r::IR
30+
end
31+
32+
struct Sub <: IR
33+
l::IR
34+
r::IR
35+
end
36+
37+
struct Product <: IR
38+
l::IR
39+
r::IR
40+
end
41+
42+
struct HadamardProduct <: IR
43+
l::IR
44+
r::IR
45+
end
46+
47+
struct Power <: IR
48+
base::IR
49+
exponent::Union{Int,Rational{Int}}
50+
end
51+
52+
struct Trace <: IR
53+
arg::IR
54+
end
55+
56+
struct Diag <: IR
57+
arg::IR
58+
end
59+
60+
struct Transpose <: IR
61+
arg::IR
62+
end
63+
64+
struct Sum <: IR
65+
arg::IR
66+
end
67+
68+
end

0 commit comments

Comments
 (0)