Skip to content

Commit da9fbb0

Browse files
committed
Add ir.get_variables
1 parent 94cbaf5 commit da9fbb0

1 file changed

Lines changed: 82 additions & 0 deletions

File tree

src/ir.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,88 @@ struct PartialSum <: IR
7878
dim::Int
7979
end
8080

81+
function _get_variables(arg::Mat)
82+
if arg.id isa String
83+
return arg.id
84+
end
85+
86+
return nothing
87+
end
88+
89+
function _get_variables(arg::Vec)
90+
if arg.id isa String
91+
return arg.id
92+
end
93+
94+
return nothing
95+
end
96+
97+
function _get_variables(arg::Scal)
98+
if arg.id isa String
99+
return arg.id
100+
end
101+
102+
return nothing
103+
end
104+
105+
function _get_variables(arg::Real)
106+
return nothing
107+
end
108+
109+
function _get_variables(arg::Identity)
110+
return nothing
111+
end
112+
113+
function _get_variables(arg::Sin)
114+
return _get_variables(arg.arg)
115+
end
116+
117+
function _get_variables(arg::Cos)
118+
return _get_variables(arg.arg)
119+
end
120+
121+
function _get_variables(arg::Add)
122+
return [_get_variables(arg.l); _get_variables(arg.r)]
123+
end
124+
125+
function _get_variables(arg::Sub)
126+
return [_get_variables(arg.l); _get_variables(arg.r)]
127+
end
128+
129+
function _get_variables(arg::Product)
130+
return [_get_variables(arg.l); _get_variables(arg.r)]
131+
end
132+
133+
function _get_variables(arg::HadamardProduct)
134+
return [_get_variables(arg.l); _get_variables(arg.r)]
135+
end
136+
137+
function _get_variables(arg::Power)
138+
return [_get_variables(arg.base); _get_variables(arg.exponent)]
139+
end
140+
141+
function _get_variables(arg::Trace)
142+
return _get_variables(arg.arg)
143+
end
144+
145+
function _get_variables(arg::Diag)
146+
return _get_variables(arg.arg)
147+
end
148+
149+
function _get_variables(arg::Transpose)
150+
return _get_variables(arg.arg)
151+
end
152+
153+
function _get_variables(arg::Sum)
154+
return _get_variables(arg.arg)
155+
end
156+
157+
function get_variables(arg::IR)
158+
s = [_get_variables(arg);]
159+
160+
return filter(x -> !isnothing(x), unique(s))
161+
end
162+
81163
end
82164

83165
"""

0 commit comments

Comments
 (0)