Skip to content

Commit 7a37de1

Browse files
committed
Move to_std_str to separate file
1 parent 877fe01 commit 7a37de1

3 files changed

Lines changed: 140 additions & 135 deletions

File tree

src/DiffMatic.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ include("simplify.jl")
1111
include("std.jl")
1212
include("forward.jl")
1313
include("julia.jl")
14+
include("stdstr.jl")
1415

1516
end

src/std.jl

Lines changed: 0 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -309,141 +309,6 @@ function to_standard(arg::BinaryOperation{Mult})
309309
throw_not_std(arg)
310310
end
311311

312-
function to_std_str(arg::ir.Mat)
313-
if arg.id isa ir.Var
314-
return to_std_str(arg.id)
315-
end
316-
317-
return "mat(" * to_std_str(arg.id) * ")"
318-
end
319-
320-
function to_std_str(arg::ir.Vec)
321-
if arg.id isa ir.Var
322-
return to_std_str(arg.id)
323-
end
324-
325-
return "vec(" * to_std_str(arg.id) * ")"
326-
end
327-
328-
function to_std_str(arg::ir.Scal)
329-
return to_std_str(arg.id)
330-
end
331-
332-
function to_std_str(arg::ir.Var)
333-
return arg.id
334-
end
335-
336-
function to_std_str(arg::ir.Const)
337-
return to_std_str(arg.value)
338-
end
339-
340-
function to_std_str(arg::Real)
341-
out = string(arg)
342-
343-
if arg < 0
344-
out = "(" * out * ")"
345-
end
346-
347-
return out
348-
end
349-
350-
function to_std_str(arg::Rational)
351-
out = string(arg)
352-
353-
return "(" * out * ")"
354-
end
355-
356-
function to_std_str(arg::ir.Identity)
357-
return "I"
358-
end
359-
360-
function to_std_str(arg::ir.Abs)
361-
return "abs(" * to_std_str(arg.arg) * ")"
362-
end
363-
364-
function to_std_str(arg::ir.Sgn)
365-
return "sgn(" * to_std_str(arg.arg) * ")"
366-
end
367-
368-
function to_std_str(arg::ir.Sin)
369-
return "sin(" * to_std_str(arg.arg) * ")"
370-
end
371-
372-
function to_std_str(arg::ir.Cos)
373-
return "cos(" * to_std_str(arg.arg) * ")"
374-
end
375-
376-
function parenthesize(f, arg::ir.Add)
377-
return "(" * f(arg) * ")"
378-
end
379-
380-
function parenthesize(f, arg::ir.Sub)
381-
return "(" * f(arg) * ")"
382-
end
383-
384-
function parenthesize(f, arg::ir.HadamardProduct)
385-
return "(" * f(arg.l) * "" * f(arg.r) * ")"
386-
end
387-
388-
function parenthesize(f, arg)
389-
return f(arg)
390-
end
391-
392-
function to_std_str(arg::ir.Add)
393-
return parenthesize(to_std_str, arg.l) * " + " * parenthesize(to_std_str, arg.r)
394-
end
395-
396-
function to_std_str(arg::ir.Sub)
397-
return parenthesize(to_std_str, arg.l) * " - " * parenthesize(to_std_str, arg.r)
398-
end
399-
400-
function to_std_str(arg::ir.Product)
401-
return parenthesize(to_std_str, arg.l) * parenthesize(to_std_str, arg.r)
402-
end
403-
404-
function to_std_str(arg::ir.HadamardProduct)
405-
return to_std_str(arg.l) * "" * to_std_str(arg.r)
406-
end
407-
408-
function to_std_str(arg::ir.Power)
409-
out = to_std_str(arg.base)
410-
411-
if arg.base isa ir.Product ||
412-
arg.base isa ir.HadamardProduct ||
413-
arg.base isa ir.Add ||
414-
arg.base isa ir.Sub
415-
out = "(" * out * ")"
416-
end
417-
418-
return out * "^" * to_std_str(arg.exponent)
419-
end
420-
421-
function to_std_str(arg::ir.Trace)
422-
return "tr(" * to_std_str(arg.arg) * ")"
423-
end
424-
425-
function to_std_str(arg::ir.Diag)
426-
return "diag(" * to_std_str(arg.arg) * ")"
427-
end
428-
429-
function to_std_str(arg::ir.Transpose)
430-
return parenthesize(to_std_str, arg.arg) * ""
431-
end
432-
433-
function to_std_str(arg::ir.Sum)
434-
return "sum(" * to_std_str(arg.arg) * ")"
435-
end
436-
437-
function to_std_str(arg::ir.PartialSum)
438-
if arg.dim == 1
439-
return "vec(1)ᵀ" * to_std_str(arg.arg)
440-
elseif arg.dim == 2
441-
return to_std_str(arg.arg) * "vec(1)"
442-
end
443-
444-
throw(RuntimeError("Encountered a sum over an unsupported index"))
445-
end
446-
447312
function standardize(arg)
448313
arg = simplify(arg)
449314
free_indices = unique(get_free_indices(arg))

src/stdstr.jl

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# This Source Code Form is subject to the terms of the Mozilla Public
2+
# License, v. 2.0. If a copy of the MPL was not distributed with this
3+
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
4+
5+
6+
function to_std_str(arg::ir.Mat)
7+
if arg.id isa ir.Var
8+
return to_std_str(arg.id)
9+
end
10+
11+
return "mat(" * to_std_str(arg.id) * ")"
12+
end
13+
14+
function to_std_str(arg::ir.Vec)
15+
if arg.id isa ir.Var
16+
return to_std_str(arg.id)
17+
end
18+
19+
return "vec(" * to_std_str(arg.id) * ")"
20+
end
21+
22+
function to_std_str(arg::ir.Scal)
23+
return to_std_str(arg.id)
24+
end
25+
26+
function to_std_str(arg::ir.Var)
27+
return arg.id
28+
end
29+
30+
function to_std_str(arg::ir.Const)
31+
return to_std_str(arg.value)
32+
end
33+
34+
function to_std_str(arg::Real)
35+
out = string(arg)
36+
37+
if arg < 0
38+
out = "(" * out * ")"
39+
end
40+
41+
return out
42+
end
43+
44+
function to_std_str(arg::Rational)
45+
out = string(arg)
46+
47+
return "(" * out * ")"
48+
end
49+
50+
function to_std_str(arg::ir.Identity)
51+
return "I"
52+
end
53+
54+
function to_std_str(arg::ir.Abs)
55+
return "abs(" * to_std_str(arg.arg) * ")"
56+
end
57+
58+
function to_std_str(arg::ir.Sgn)
59+
return "sgn(" * to_std_str(arg.arg) * ")"
60+
end
61+
62+
function to_std_str(arg::ir.Sin)
63+
return "sin(" * to_std_str(arg.arg) * ")"
64+
end
65+
66+
function to_std_str(arg::ir.Cos)
67+
return "cos(" * to_std_str(arg.arg) * ")"
68+
end
69+
70+
function parenthesize(f, arg::ir.Add)
71+
return "(" * f(arg) * ")"
72+
end
73+
74+
function parenthesize(f, arg::ir.Sub)
75+
return "(" * f(arg) * ")"
76+
end
77+
78+
function parenthesize(f, arg::ir.HadamardProduct)
79+
return "(" * f(arg.l) * "" * f(arg.r) * ")"
80+
end
81+
82+
function parenthesize(f, arg)
83+
return f(arg)
84+
end
85+
86+
function to_std_str(arg::ir.Add)
87+
return parenthesize(to_std_str, arg.l) * " + " * parenthesize(to_std_str, arg.r)
88+
end
89+
90+
function to_std_str(arg::ir.Sub)
91+
return parenthesize(to_std_str, arg.l) * " - " * parenthesize(to_std_str, arg.r)
92+
end
93+
94+
function to_std_str(arg::ir.Product)
95+
return parenthesize(to_std_str, arg.l) * parenthesize(to_std_str, arg.r)
96+
end
97+
98+
function to_std_str(arg::ir.HadamardProduct)
99+
return to_std_str(arg.l) * "" * to_std_str(arg.r)
100+
end
101+
102+
function to_std_str(arg::ir.Power)
103+
out = to_std_str(arg.base)
104+
105+
if arg.base isa ir.Product ||
106+
arg.base isa ir.HadamardProduct ||
107+
arg.base isa ir.Add ||
108+
arg.base isa ir.Sub
109+
out = "(" * out * ")"
110+
end
111+
112+
return out * "^" * to_std_str(arg.exponent)
113+
end
114+
115+
function to_std_str(arg::ir.Trace)
116+
return "tr(" * to_std_str(arg.arg) * ")"
117+
end
118+
119+
function to_std_str(arg::ir.Diag)
120+
return "diag(" * to_std_str(arg.arg) * ")"
121+
end
122+
123+
function to_std_str(arg::ir.Transpose)
124+
return parenthesize(to_std_str, arg.arg) * ""
125+
end
126+
127+
function to_std_str(arg::ir.Sum)
128+
return "sum(" * to_std_str(arg.arg) * ")"
129+
end
130+
131+
function to_std_str(arg::ir.PartialSum)
132+
if arg.dim == 1
133+
return "vec(1)ᵀ" * to_std_str(arg.arg)
134+
elseif arg.dim == 2
135+
return to_std_str(arg.arg) * "vec(1)"
136+
end
137+
138+
throw(RuntimeError("Encountered a sum over an unsupported index"))
139+
end

0 commit comments

Comments
 (0)