@@ -35,64 +35,16 @@ def run_interpreter(proc, kwargs):
35
35
Interpreter (proc , kwargs )
36
36
37
37
38
- # context is global
39
- ctxt = defaultdict (dict )
40
-
41
38
class Interpreter :
42
39
def __init__ (self , proc , kwargs , use_randomization = False ):
43
- assert isinstance (proc , LoopIR .proc )
44
-
45
- proc = ParallelAnalysis ().run (proc )
46
- proc = PrecisionAnalysis ().run (proc ) # TODO: need this?
47
- proc = WindowAnalysis ().apply_proc (proc )
48
- proc = MemoryAnalysis ().run (proc ) # TODO: need this?
40
+ if not isinstance (proc , LoopIR .proc ):
41
+ raise TypeError (f"Expected { proc .name } to be of type proc" )
49
42
50
- self .proc = proc
51
43
self .env = ChainMap ()
52
44
self .use_randomization = use_randomization
45
+ self .ctxt = defaultdict (dict )
53
46
54
- # type check args
55
- for a in proc .args :
56
- if not str (a .name ) in kwargs :
57
- raise TypeError (f"expected argument '{ a .name } ' to be supplied" )
58
-
59
- if a .type is T .size :
60
- if not is_pos_int (kwargs [str (a .name )]):
61
- raise TypeError (
62
- f"expected size '{ a .name } ' to have positive integer value"
63
- )
64
- self .env [a .name ] = kwargs [str (a .name )]
65
- elif a .type is T .index :
66
- if type (kwargs [str (a .name )]) is not int :
67
- raise TypeError (
68
- f"expected index variable '{ a .name } ' to be an integer"
69
- )
70
- self .env [a .name ] = kwargs [str (a .name )]
71
- elif a .type is T .bool :
72
- if type (kwargs [str (a .name )]) is not bool :
73
- raise TypeError (f"expected bool variable '{ a .name } ' to be a bool" )
74
- self .env [a .name ] = kwargs [str (a .name )]
75
- elif a .type is T .stride :
76
- if type (kwargs [str (a .name )]) is not int :
77
- raise TypeError (
78
- f"expected stride variable '{ a .name } ' to be an integer"
79
- )
80
- self .env [a .name ] = kwargs [str (a .name )]
81
- else :
82
- self .typecheck_input_buffer (a , kwargs )
83
- self .env [a .name ] = kwargs [str (a .name )]
84
-
85
- # evaluate preconditions
86
- for pred in proc .preds :
87
- if isinstance (pred , LoopIR .Const ):
88
- continue
89
- else :
90
- assert self .eval_e (pred ), "precondition not satisfied"
91
-
92
- # eval statements
93
- self .env = self .env .new_child ()
94
- self .eval_stmts (proc .body )
95
- self .env = self .env .parents
47
+ self .eval_proc (proc , kwargs )
96
48
97
49
def _new_scope (self ):
98
50
self .env = self .env .new_child ()
@@ -154,14 +106,60 @@ def typecheck_input_buffer(self, proc_arg, kwargs):
154
106
f"but got shape { tuple (buf .shape )} "
155
107
)
156
108
109
+ def eval_proc (self , proc , kwargs ):
110
+ proc = ParallelAnalysis ().run (proc )
111
+ proc = PrecisionAnalysis ().run (proc ) # TODO: need this?
112
+ proc = WindowAnalysis ().apply_proc (proc )
113
+ proc = MemoryAnalysis ().run (proc ) # TODO: need this?
114
+
115
+ for a in proc .args :
116
+ if not str (a .name ) in kwargs :
117
+ raise TypeError (f"expected argument '{ a .name } ' to be supplied" )
118
+
119
+ if a .type is T .size :
120
+ if not is_pos_int (kwargs [str (a .name )]):
121
+ raise TypeError (
122
+ f"expected size '{ a .name } ' to have positive integer value"
123
+ )
124
+ self .env [a .name ] = kwargs [str (a .name )]
125
+ elif a .type is T .index :
126
+ if type (kwargs [str (a .name )]) is not int :
127
+ raise TypeError (
128
+ f"expected index variable '{ a .name } ' to be an integer"
129
+ )
130
+ self .env [a .name ] = kwargs [str (a .name )]
131
+ elif a .type is T .bool :
132
+ if type (kwargs [str (a .name )]) is not bool :
133
+ raise TypeError (f"expected bool variable '{ a .name } ' to be a bool" )
134
+ self .env [a .name ] = kwargs [str (a .name )]
135
+ elif a .type is T .stride :
136
+ if type (kwargs [str (a .name )]) is not int :
137
+ raise TypeError (
138
+ f"expected stride variable '{ a .name } ' to be an integer"
139
+ )
140
+ self .env [a .name ] = kwargs [str (a .name )]
141
+ else :
142
+ self .typecheck_input_buffer (a , kwargs )
143
+ self .env [a .name ] = kwargs [str (a .name )]
144
+
145
+ # evaluate preconditions
146
+ for pred in proc .preds :
147
+ if isinstance (pred , LoopIR .Const ):
148
+ continue
149
+ else :
150
+ assert self .eval_e (pred ), "precondition not satisfied"
151
+
152
+ # eval statements
153
+ self .eval_stmts (proc .body )
154
+
157
155
def eval_stmts (self , stmts ):
158
156
for s in stmts :
159
157
self .eval_s (s )
160
158
161
159
def eval_s (self , s ):
162
160
if isinstance (s , LoopIR .Pass ):
163
161
pass
164
-
162
+
165
163
elif isinstance (s , (LoopIR .Assign , LoopIR .Reduce )):
166
164
lbuf = self .env [s .name ]
167
165
if len (s .idx ) == 0 :
@@ -179,12 +177,14 @@ def eval_s(self, s):
179
177
elif isinstance (s , LoopIR .WriteConfig ):
180
178
nm = s .config .name ()
181
179
rhs = self .eval_e (s .rhs )
182
- ctxt [nm ][s .field ] = rhs
180
+ self . ctxt [nm ][s .field ] = rhs
183
181
184
182
elif isinstance (s , LoopIR .WindowStmt ):
185
183
# nm = rbuf[...]
186
184
assert s .name not in self .env , "WindowStmt should be a fresh assignment"
187
- assert isinstance (s .rhs , LoopIR .WindowExpr ), "WindowStmt rhs should be WindowExpr"
185
+ assert isinstance (
186
+ s .rhs , LoopIR .WindowExpr
187
+ ), "WindowStmt rhs should be WindowExpr"
188
188
self .env [s .name ] = self .eval_e (s .rhs )
189
189
190
190
elif isinstance (s , LoopIR .If ):
@@ -225,7 +225,9 @@ def eval_s(self, s):
225
225
argvals = [self .eval_e (a , call_arg = True ) for a in s .args ]
226
226
argnames = [str (a .name ) for a in s .f .args ]
227
227
kwargs = {nm : val for nm , val in zip (argnames , argvals )}
228
- Interpreter (s .f , kwargs , use_randomization = self .use_randomization )
228
+ self ._new_scope ()
229
+ self .eval_proc (s .f , kwargs )
230
+ self ._del_scope ()
229
231
230
232
else :
231
233
assert False , "bad statement case"
@@ -253,10 +255,14 @@ def stringify_w_access(a):
253
255
assert False , "bad w_access case"
254
256
255
257
# hack to handle interval indexes: LoopIR.Interval returns a string representing the interval
256
- idx = ("0" ,) if len (e .idx ) == 0 else tuple (stringify_w_access (a ) for a in e .idx )
258
+ idx = (
259
+ ("0" ,)
260
+ if len (e .idx ) == 0
261
+ else tuple (stringify_w_access (a ) for a in e .idx )
262
+ )
257
263
res = eval (f"buf[{ ',' .join (idx )} ]" )
258
264
return res
259
-
265
+
260
266
elif isinstance (e , LoopIR .Const ):
261
267
return e .val
262
268
@@ -268,9 +274,12 @@ def stringify_w_access(a):
268
274
return lhs - rhs
269
275
elif e .op == "*" :
270
276
return lhs * rhs
271
- elif e .op == "/" : # is this right?
272
- if isinstance (lhs , int ):
273
- return (lhs + rhs - 1 ) // rhs
277
+ elif e .op == "/" :
278
+ if isinstance (lhs , int ) and isinstance (rhs , int ):
279
+ # this is what was here before and without the rhs check
280
+ # counter example of why this is wrong -3 / 2 == -1 in C and 0 in this impl
281
+ # return (lhs + rhs - 1) // rhs
282
+ return int (lhs / rhs )
274
283
else :
275
284
return lhs / rhs
276
285
elif e .op == "%" :
@@ -293,9 +302,12 @@ def stringify_w_access(a):
293
302
elif isinstance (e , LoopIR .USub ):
294
303
return - self .eval_e (e .arg )
295
304
305
+ # BuiltIns don't go to the interpreter, they are just called (via call) like a proc
306
+ # TODO Discuss to make sure
296
307
elif isinstance (e , LoopIR .BuiltIn ):
297
- args = [self .eval_e (a ) for a in e .args ]
298
- return e .f .interpret (args )
308
+ assert False , "Not implemented"
309
+ # args = [self.eval_e(a) for a in e.args]
310
+ # return e.f.interpret(args)
299
311
300
312
elif isinstance (e , LoopIR .StrideExpr ):
301
313
buf = self .env [e .name ]
@@ -305,7 +317,7 @@ def stringify_w_access(a):
305
317
306
318
elif isinstance (e , LoopIR .ReadConfig ):
307
319
nm = e .config .name ()
308
- return ctxt [nm ][e .field ]
320
+ return self . ctxt [nm ][e .field ]
309
321
310
322
else :
311
323
print (e )
0 commit comments