Skip to content

Commit dd8784c

Browse files
committed
refactor interpreter; fix C division bug; fix some of the failing tests; add 'built-in' test
1 parent 27ae79e commit dd8784c

File tree

2 files changed

+240
-154
lines changed

2 files changed

+240
-154
lines changed

src/exo/LoopIR_interpreter.py

Lines changed: 76 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -35,64 +35,16 @@ def run_interpreter(proc, kwargs):
3535
Interpreter(proc, kwargs)
3636

3737

38-
# context is global
39-
ctxt = defaultdict(dict)
40-
4138
class Interpreter:
4239
def __init__(self, proc, kwargs, use_randomization=False):
43-
assert isinstance(proc, LoopIR.proc)
44-
45-
proc = ParallelAnalysis().run(proc)
46-
proc = PrecisionAnalysis().run(proc) # TODO: need this?
47-
proc = WindowAnalysis().apply_proc(proc)
48-
proc = MemoryAnalysis().run(proc) # TODO: need this?
40+
if not isinstance(proc, LoopIR.proc):
41+
raise TypeError(f"Expected {proc.name} to be of type proc")
4942

50-
self.proc = proc
5143
self.env = ChainMap()
5244
self.use_randomization = use_randomization
45+
self.ctxt = defaultdict(dict)
5346

54-
# type check args
55-
for a in proc.args:
56-
if not str(a.name) in kwargs:
57-
raise TypeError(f"expected argument '{a.name}' to be supplied")
58-
59-
if a.type is T.size:
60-
if not is_pos_int(kwargs[str(a.name)]):
61-
raise TypeError(
62-
f"expected size '{a.name}' to have positive integer value"
63-
)
64-
self.env[a.name] = kwargs[str(a.name)]
65-
elif a.type is T.index:
66-
if type(kwargs[str(a.name)]) is not int:
67-
raise TypeError(
68-
f"expected index variable '{a.name}' to be an integer"
69-
)
70-
self.env[a.name] = kwargs[str(a.name)]
71-
elif a.type is T.bool:
72-
if type(kwargs[str(a.name)]) is not bool:
73-
raise TypeError(f"expected bool variable '{a.name}' to be a bool")
74-
self.env[a.name] = kwargs[str(a.name)]
75-
elif a.type is T.stride:
76-
if type(kwargs[str(a.name)]) is not int:
77-
raise TypeError(
78-
f"expected stride variable '{a.name}' to be an integer"
79-
)
80-
self.env[a.name] = kwargs[str(a.name)]
81-
else:
82-
self.typecheck_input_buffer(a, kwargs)
83-
self.env[a.name] = kwargs[str(a.name)]
84-
85-
# evaluate preconditions
86-
for pred in proc.preds:
87-
if isinstance(pred, LoopIR.Const):
88-
continue
89-
else:
90-
assert self.eval_e(pred), "precondition not satisfied"
91-
92-
# eval statements
93-
self.env = self.env.new_child()
94-
self.eval_stmts(proc.body)
95-
self.env = self.env.parents
47+
self.eval_proc(proc, kwargs)
9648

9749
def _new_scope(self):
9850
self.env = self.env.new_child()
@@ -154,14 +106,60 @@ def typecheck_input_buffer(self, proc_arg, kwargs):
154106
f"but got shape {tuple(buf.shape)}"
155107
)
156108

109+
def eval_proc(self, proc, kwargs):
110+
proc = ParallelAnalysis().run(proc)
111+
proc = PrecisionAnalysis().run(proc) # TODO: need this?
112+
proc = WindowAnalysis().apply_proc(proc)
113+
proc = MemoryAnalysis().run(proc) # TODO: need this?
114+
115+
for a in proc.args:
116+
if not str(a.name) in kwargs:
117+
raise TypeError(f"expected argument '{a.name}' to be supplied")
118+
119+
if a.type is T.size:
120+
if not is_pos_int(kwargs[str(a.name)]):
121+
raise TypeError(
122+
f"expected size '{a.name}' to have positive integer value"
123+
)
124+
self.env[a.name] = kwargs[str(a.name)]
125+
elif a.type is T.index:
126+
if type(kwargs[str(a.name)]) is not int:
127+
raise TypeError(
128+
f"expected index variable '{a.name}' to be an integer"
129+
)
130+
self.env[a.name] = kwargs[str(a.name)]
131+
elif a.type is T.bool:
132+
if type(kwargs[str(a.name)]) is not bool:
133+
raise TypeError(f"expected bool variable '{a.name}' to be a bool")
134+
self.env[a.name] = kwargs[str(a.name)]
135+
elif a.type is T.stride:
136+
if type(kwargs[str(a.name)]) is not int:
137+
raise TypeError(
138+
f"expected stride variable '{a.name}' to be an integer"
139+
)
140+
self.env[a.name] = kwargs[str(a.name)]
141+
else:
142+
self.typecheck_input_buffer(a, kwargs)
143+
self.env[a.name] = kwargs[str(a.name)]
144+
145+
# evaluate preconditions
146+
for pred in proc.preds:
147+
if isinstance(pred, LoopIR.Const):
148+
continue
149+
else:
150+
assert self.eval_e(pred), "precondition not satisfied"
151+
152+
# eval statements
153+
self.eval_stmts(proc.body)
154+
157155
def eval_stmts(self, stmts):
158156
for s in stmts:
159157
self.eval_s(s)
160158

161159
def eval_s(self, s):
162160
if isinstance(s, LoopIR.Pass):
163161
pass
164-
162+
165163
elif isinstance(s, (LoopIR.Assign, LoopIR.Reduce)):
166164
lbuf = self.env[s.name]
167165
if len(s.idx) == 0:
@@ -179,12 +177,14 @@ def eval_s(self, s):
179177
elif isinstance(s, LoopIR.WriteConfig):
180178
nm = s.config.name()
181179
rhs = self.eval_e(s.rhs)
182-
ctxt[nm][s.field] = rhs
180+
self.ctxt[nm][s.field] = rhs
183181

184182
elif isinstance(s, LoopIR.WindowStmt):
185183
# nm = rbuf[...]
186184
assert s.name not in self.env, "WindowStmt should be a fresh assignment"
187-
assert isinstance(s.rhs, LoopIR.WindowExpr), "WindowStmt rhs should be WindowExpr"
185+
assert isinstance(
186+
s.rhs, LoopIR.WindowExpr
187+
), "WindowStmt rhs should be WindowExpr"
188188
self.env[s.name] = self.eval_e(s.rhs)
189189

190190
elif isinstance(s, LoopIR.If):
@@ -225,7 +225,9 @@ def eval_s(self, s):
225225
argvals = [self.eval_e(a, call_arg=True) for a in s.args]
226226
argnames = [str(a.name) for a in s.f.args]
227227
kwargs = {nm: val for nm, val in zip(argnames, argvals)}
228-
Interpreter(s.f, kwargs, use_randomization=self.use_randomization)
228+
self._new_scope()
229+
self.eval_proc(s.f, kwargs)
230+
self._del_scope()
229231

230232
else:
231233
assert False, "bad statement case"
@@ -253,10 +255,14 @@ def stringify_w_access(a):
253255
assert False, "bad w_access case"
254256

255257
# hack to handle interval indexes: LoopIR.Interval returns a string representing the interval
256-
idx = ("0",) if len(e.idx) == 0 else tuple(stringify_w_access(a) for a in e.idx)
258+
idx = (
259+
("0",)
260+
if len(e.idx) == 0
261+
else tuple(stringify_w_access(a) for a in e.idx)
262+
)
257263
res = eval(f"buf[{','.join(idx)}]")
258264
return res
259-
265+
260266
elif isinstance(e, LoopIR.Const):
261267
return e.val
262268

@@ -268,9 +274,12 @@ def stringify_w_access(a):
268274
return lhs - rhs
269275
elif e.op == "*":
270276
return lhs * rhs
271-
elif e.op == "/": # is this right?
272-
if isinstance(lhs, int):
273-
return (lhs + rhs - 1) // rhs
277+
elif e.op == "/":
278+
if isinstance(lhs, int) and isinstance(rhs, int):
279+
# this is what was here before and without the rhs check
280+
# counter example of why this is wrong -3 / 2 == -1 in C and 0 in this impl
281+
# return (lhs + rhs - 1) // rhs
282+
return int(lhs / rhs)
274283
else:
275284
return lhs / rhs
276285
elif e.op == "%":
@@ -293,9 +302,12 @@ def stringify_w_access(a):
293302
elif isinstance(e, LoopIR.USub):
294303
return -self.eval_e(e.arg)
295304

305+
# BuiltIns don't go to the interpreter, they are just called (via call) like a proc
306+
# TODO Discuss to make sure
296307
elif isinstance(e, LoopIR.BuiltIn):
297-
args = [self.eval_e(a) for a in e.args]
298-
return e.f.interpret(args)
308+
assert False, "Not implemented"
309+
# args = [self.eval_e(a) for a in e.args]
310+
# return e.f.interpret(args)
299311

300312
elif isinstance(e, LoopIR.StrideExpr):
301313
buf = self.env[e.name]
@@ -305,7 +317,7 @@ def stringify_w_access(a):
305317

306318
elif isinstance(e, LoopIR.ReadConfig):
307319
nm = e.config.name()
308-
return ctxt[nm][e.field]
320+
return self.ctxt[nm][e.field]
309321

310322
else:
311323
print(e)

0 commit comments

Comments
 (0)