Skip to content

Commit 9a2fecd

Browse files
committed
runtime
1 parent bd6ae4a commit 9a2fecd

File tree

2 files changed

+118
-127
lines changed

2 files changed

+118
-127
lines changed

src/exo/analysis.py

Lines changed: 118 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from .dataflow import (
1010
LoopIR_to_DataflowIR,
1111
ScalarPropagation,
12-
GetControlPredicates,
1312
GetValues,
1413
D,
1514
)
@@ -376,6 +375,71 @@ def lift_es(es):
376375
return [lift_e(e) for e in es]
377376

378377

378+
# --------------------------------------------------------------------------- #
379+
# Getting control flow on DataflowIR. Will be unnecessary when we
380+
# integrate control flow into abstract values.
381+
# --------------------------------------------------------------------------- #
382+
383+
384+
class GetControlPredicates(LoopIR_Do):
385+
def __init__(self, proc, stmts):
386+
self.proc = proc
387+
self.stmts = stmts
388+
self.preds = None
389+
self.done = False
390+
self.cur_preds = []
391+
392+
for a in self.proc.args:
393+
if isinstance(a.type, T.Size):
394+
size_pred = A.BinOp(
395+
"<",
396+
A.Const(0, T.int, null_srcinfo()),
397+
A.Var(a.name, T.size, a.srcinfo),
398+
T.bool,
399+
null_srcinfo(),
400+
)
401+
self.cur_preds.append(size_pred)
402+
self.do_t(a.type)
403+
404+
for pred in self.proc.preds:
405+
self.cur_preds.append(lift_e(pred))
406+
self.do_e(pred)
407+
408+
self.do_stmts(self.proc.body)
409+
410+
def do_s(self, s):
411+
if self.done:
412+
return
413+
414+
if s == self.stmts[0]:
415+
self.preds = AAnd(*self.cur_preds)
416+
self.done = True
417+
418+
styp = type(s)
419+
if styp is LoopIR.If:
420+
self.cur_preds.append(lift_e(s.cond))
421+
self.do_stmts(s.body)
422+
self.cur_preds.pop()
423+
424+
self.cur_preds.append(A.Not(lift_e(s.cond), T.int, null_srcinfo()))
425+
self.do_stmts(s.orelse)
426+
self.cur_preds.pop()
427+
428+
elif styp is LoopIR.For:
429+
a_iter = A.Var(s.iter, T.int, s.srcinfo)
430+
b1 = A.BinOp("<=", lift_e(s.lo), a_iter, T.bool, null_srcinfo())
431+
b2 = A.BinOp("<", a_iter, lift_e(s.hi), T.bool, null_srcinfo())
432+
cond = A.BinOp("and", b1, b2, T.bool, null_srcinfo())
433+
self.cur_preds.append(cond)
434+
self.do_stmts(s.body)
435+
self.cur_preds.pop()
436+
437+
super().do_s(s)
438+
439+
def result(self):
440+
return self.preds.simplify()
441+
442+
379443
# Produce a set of AExprs which occur as right-hand-sides
380444
# of config writes.
381445
def possible_config_writes(stmts):
@@ -1531,11 +1595,13 @@ def loop_globenv(i, lo_expr, hi_expr, body):
15311595

15321596

15331597
def Check_ReorderStmts(proc, s1, s2):
1534-
datair, stmts = LoopIR_to_DataflowIR(proc, [s1, s2]).result()
1598+
# datair, stmts = LoopIR_to_DataflowIR(proc, [s1, s2]).result()
1599+
1600+
# print("here in ReorderStmts")
15351601

1536-
assert len(stmts) == 2
1602+
assert isinstance(s1, LoopIR.stmt) and isinstance(s2, LoopIR.stmt)
15371603

1538-
p = GetControlPredicates(datair, stmts).result()
1604+
p = GetControlPredicates(proc, [s1, s2]).result()
15391605

15401606
slv = SMTSolver(verbose=False)
15411607
slv.push()
@@ -1554,11 +1620,13 @@ def Check_ReorderStmts(proc, s1, s2):
15541620

15551621

15561622
def Check_ReorderLoops(proc, s):
1557-
datair, stmts = LoopIR_to_DataflowIR(proc, [s]).result()
1623+
# datair, stmts = LoopIR_to_DataflowIR(proc, [s]).result()
15581624

1559-
assert len(stmts) == 1
1625+
# print("here in ReorderLoops")
15601626

1561-
p = GetControlPredicates(datair, stmts).result()
1627+
assert isinstance(s, LoopIR.For)
1628+
1629+
p = GetControlPredicates(proc, [s]).result()
15621630

15631631
slv = SMTSolver(verbose=False)
15641632
slv.push()
@@ -1632,11 +1700,13 @@ def bds(x, lo, hi):
16321700
# /\ ( forall i,i'. May(InBound(i,i',e) /\ i < i') => Commutes(a1', a1) )
16331701
#
16341702
def Check_ParallelizeLoop(proc, s):
1635-
datair, stmts = LoopIR_to_DataflowIR(proc, [s]).result()
1703+
# datair, stmts = LoopIR_to_DataflowIR(proc, [s]).result()
1704+
1705+
# print("Check_ParallelizeLoop")
16361706

1637-
assert len(stmts) == 1
1707+
assert isinstance(s, LoopIR.For)
16381708

1639-
p = GetControlPredicates(datair, stmts).result()
1709+
p = GetControlPredicates(proc, [s]).result()
16401710

16411711
slv = SMTSolver(verbose=False)
16421712
slv.push()
@@ -1688,9 +1758,11 @@ def bds(x, lo, hi):
16881758
#
16891759
def Check_FissionLoop(proc, loop, stmts1, stmts2, no_loop_var_1=False):
16901760

1691-
datair, d_loop = LoopIR_to_DataflowIR(proc, [loop]).result()
1761+
# print("Check_FissionLoop")
16921762

1693-
p = GetControlPredicates(datair, d_loop).result()
1763+
# datair, d_loop = LoopIR_to_DataflowIR(proc, [loop]).result()
1764+
1765+
p = GetControlPredicates(proc, [loop]).result()
16941766

16951767
slv = SMTSolver(verbose=False)
16961768
slv.push()
@@ -1774,9 +1846,9 @@ def lift_dexpr(e, key=None):
17741846
def Check_DeleteConfigWrite(proc, stmts):
17751847
assert len(stmts) > 0
17761848

1777-
ir1, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()
1778-
p = GetControlPredicates(ir1, d_stmts).result()
1849+
# print("here in DeleteConfigWrite")
17791850

1851+
p = GetControlPredicates(proc, stmts).result()
17801852
slv = SMTSolver(verbose=False)
17811853
slv.push()
17821854
slv.assume(AMay(p))
@@ -1801,6 +1873,7 @@ def Check_DeleteConfigWrite(proc, stmts):
18011873
)
18021874

18031875
# Below are the actual checks
1876+
ir1, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()
18041877

18051878
ScalarPropagation(ir1)
18061879

@@ -1869,6 +1942,8 @@ def Check_ExtendEqv(proc1, proc2, stmts1, stmts2, cfg_mod):
18691942
assert len(stmts1) == 1
18701943
assert len(stmts2) == 1
18711944

1945+
# print("here in Check_ExtendEqv")
1946+
18721947
slv = SMTSolver(verbose=False)
18731948
slv.push()
18741949

@@ -1928,16 +2003,18 @@ def make_point(key):
19282003

19292004

19302005
def Check_ExprEqvInContext(proc, expr0, stmts0, expr1, stmts1=None):
2006+
2007+
# print("Check_ExprEqvInContext")
19312008
assert len(stmts0) > 0
19322009
stmts1 = stmts1 or stmts0
19332010

1934-
len_0 = len(stmts0)
1935-
datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts0 + stmts1).result()
1936-
d_stmts0 = d_stmts[0:len_0]
1937-
d_stmts1 = d_stmts[len_0:]
2011+
# len_0 = len(stmts0)
2012+
# datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts0 + stmts1).result()
2013+
# d_stmts0 = d_stmts[0:len_0]
2014+
# d_stmts1 = d_stmts[len_0:]
19382015

1939-
p0 = GetControlPredicates(datair, d_stmts0).result()
1940-
p1 = GetControlPredicates(datair, d_stmts1).result()
2016+
p0 = GetControlPredicates(proc, stmts0).result()
2017+
p1 = GetControlPredicates(proc, stmts1).result()
19412018

19422019
slv = SMTSolver(verbose=False)
19432020
slv.push()
@@ -1954,11 +2031,13 @@ def Check_ExprEqvInContext(proc, expr0, stmts0, expr1, stmts1=None):
19542031

19552032

19562033
def Check_BufferReduceOnly(proc, stmts, buf, ndim):
2034+
2035+
# print("Check_BufferReduceOnly")
19572036
assert len(stmts) > 0
19582037

1959-
datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()
2038+
# datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()
19602039

1961-
p = GetControlPredicates(datair, d_stmts).result()
2040+
p = GetControlPredicates(proc, stmts).result()
19622041

19632042
slv = SMTSolver(verbose=False)
19642043
slv.push()
@@ -1988,13 +2067,15 @@ def Check_Access_In_Window(proc, access_cursor, w_exprs, block_cursor):
19882067
block_cursor is the context in which to interpret the access in.
19892068
"""
19902069

2070+
# print("Check_Access_In_Window")
2071+
19912072
access = access_cursor._node
19922073
block = [x._node for x in block_cursor]
19932074
idxs = access.idx
19942075
assert len(idxs) == len(w_exprs)
19952076

1996-
datair, d_stmts = LoopIR_to_DataflowIR(proc, block).result()
1997-
p = GetControlPredicates(datair, d_stmts).result()
2077+
# datair, d_stmts = LoopIR_to_DataflowIR(proc, block).result()
2078+
p = GetControlPredicates(proc, block).result()
19982079

19992080
slv = SMTSolver(verbose=False)
20002081
slv.push()
@@ -2067,9 +2148,10 @@ def Check_Bounds(proc, alloc_stmt, block):
20672148
if len(block) == 0:
20682149
return
20692150

2070-
datair, stmts = LoopIR_to_DataflowIR(proc, block).result()
2151+
# print("Check_Bounds")
2152+
# datair, stmts = LoopIR_to_DataflowIR(proc, block).result()
20712153

2072-
p = GetControlPredicates(datair, stmts).result()
2154+
p = GetControlPredicates(proc, block).result()
20732155

20742156
slv = SMTSolver(verbose=False)
20752157
slv.push()
@@ -2105,6 +2187,8 @@ def Check_Bounds(proc, alloc_stmt, block):
21052187

21062188

21072189
def Check_IsDeadAfter(proc, stmts, bufname, ndim):
2190+
2191+
# print("Check_IsDeadAfter")
21082192
assert len(stmts) > 0
21092193

21102194
ap = PostEnv(proc, stmts).get_posteffs()
@@ -2126,11 +2210,13 @@ def Check_IsDeadAfter(proc, stmts, bufname, ndim):
21262210

21272211

21282212
def Check_IsIdempotent(proc, stmts):
2213+
2214+
# print("Check_IsIdempotent")
21292215
assert len(stmts) > 0
21302216

2131-
datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()
2217+
# datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()
21322218

2133-
p = GetControlPredicates(datair, d_stmts).result()
2219+
p = GetControlPredicates(proc, stmts).result()
21342220

21352221
slv = SMTSolver(verbose=False)
21362222
slv.push()
@@ -2144,10 +2230,11 @@ def Check_IsIdempotent(proc, stmts):
21442230

21452231

21462232
def Check_ExprBound(proc, stmts, expr, op, value, exception=True):
2233+
# print("Check_ExprBound")
21472234
assert len(stmts) > 0
21482235

2149-
datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()
2150-
p = GetControlPredicates(datair, d_stmts).result()
2236+
# datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()
2237+
p = GetControlPredicates(proc, stmts).result()
21512238

21522239
# TODO: Check_ExprBound does not depend on configuration states so this can be skipped, but more fundamentally running abstract interpretation this many times is simply too slow.
21532240
# ScalarPropagation(datair)
@@ -2335,5 +2422,6 @@ def do_s(self, s):
23352422

23362423

23372424
def Check_Aliasing(proc):
2425+
# print("Check_Aliasing")
23382426
helper = _Check_Aliasing_Helper(proc)
23392427
# that's it

src/exo/dataflow.py

Lines changed: 0 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -829,100 +829,3 @@ def abs_builtin(self, builtin, args):
829829

830830
# TODO: write a short circuit for select builtin
831831
return D.Const(builtin.interpret(vargs), args[0].typ)
832-
833-
834-
# --------------------------------------------------------------------------- #
835-
# Getting control flow on DataflowIR. Will be unnecessary when we
836-
# integrate control flow into abstract values.
837-
# --------------------------------------------------------------------------- #
838-
839-
840-
def lift_dataflow(e):
841-
if e.type.is_indexable() or e.type.is_stridable() or e.type == T.bool:
842-
if isinstance(e, DataflowIR.Read):
843-
assert len(e.idx) == 0
844-
return A.Var(e.name, e.type, e.srcinfo)
845-
elif isinstance(e, DataflowIR.Const):
846-
return A.Const(e.val, e.type, e.srcinfo)
847-
elif isinstance(e, DataflowIR.BinOp):
848-
return A.BinOp(
849-
e.op, lift_dataflow(e.lhs), lift_dataflow(e.rhs), e.type, e.srcinfo
850-
)
851-
elif isinstance(e, DataflowIR.USub):
852-
return A.USub(lift_dataflow(e.arg), e.type, e.srcinfo)
853-
elif isinstance(e, DataflowIR.StrideExpr):
854-
return A.Stride(e.name, e.dim, e.type, e.srcinfo)
855-
elif isinstance(e, DataflowIR.ReadConfig):
856-
return A.Var(e.config_field, e.type, e.srcinfo)
857-
else:
858-
f"bad case: {type(e)}"
859-
else:
860-
assert e.type.is_numeric()
861-
if e.type.is_real_scalar():
862-
if isinstance(e, DataflowIR.Const):
863-
return A.Const(e.val, e.type, e.srcinfo)
864-
elif isinstance(e, DataflowIR.Read):
865-
return A.ConstSym(e.name, e.type, e.srcinfo)
866-
elif isinstance(e, DataflowIR.ReadConfig):
867-
return A.Var(e.config_field, e.type, e.srcinfo)
868-
869-
return A.Unk(T.err, e.srcinfo)
870-
871-
872-
class GetControlPredicates(DataflowIR_Do):
873-
def __init__(self, datair, stmts):
874-
self.datair = datair
875-
self.stmts = stmts
876-
self.preds = None
877-
self.done = False
878-
self.cur_preds = []
879-
880-
for a in self.datair.args:
881-
if isinstance(a.type, T.Size):
882-
size_pred = A.BinOp(
883-
"<",
884-
A.Const(0, T.int, null_srcinfo()),
885-
A.Var(a.name, T.size, a.srcinfo),
886-
T.bool,
887-
null_srcinfo(),
888-
)
889-
self.cur_preds.append(size_pred)
890-
self.do_t(a.type)
891-
892-
for pred in self.datair.preds:
893-
self.cur_preds.append(lift_dataflow(pred))
894-
self.do_e(pred)
895-
896-
self.do_stmts(self.datair.body.stmts)
897-
898-
def do_s(self, s):
899-
if self.done:
900-
return
901-
902-
if s == self.stmts[0]:
903-
self.preds = AAnd(*self.cur_preds)
904-
self.done = True
905-
906-
styp = type(s)
907-
if styp is DataflowIR.If:
908-
self.cur_preds.append(lift_dataflow(s.cond))
909-
self.do_stmts(s.body.stmts)
910-
self.cur_preds.pop()
911-
912-
self.cur_preds.append(A.Not(lift_dataflow(s.cond), T.int, null_srcinfo()))
913-
self.do_stmts(s.orelse.stmts)
914-
self.cur_preds.pop()
915-
916-
elif styp is DataflowIR.For:
917-
a_iter = A.Var(s.iter, T.int, s.srcinfo)
918-
b1 = A.BinOp("<=", lift_dataflow(s.lo), a_iter, T.bool, null_srcinfo())
919-
b2 = A.BinOp("<", a_iter, lift_dataflow(s.hi), T.bool, null_srcinfo())
920-
cond = A.BinOp("and", b1, b2, T.bool, null_srcinfo())
921-
self.cur_preds.append(cond)
922-
self.do_stmts(s.body.stmts)
923-
self.cur_preds.pop()
924-
925-
super().do_s(s)
926-
927-
def result(self):
928-
return self.preds.simplify()

0 commit comments

Comments
 (0)