9
9
from .dataflow import (
10
10
LoopIR_to_DataflowIR ,
11
11
ScalarPropagation ,
12
- GetControlPredicates ,
13
12
GetValues ,
14
13
D ,
15
14
)
@@ -376,6 +375,71 @@ def lift_es(es):
376
375
return [lift_e (e ) for e in es ]
377
376
378
377
378
+ # --------------------------------------------------------------------------- #
379
+ # Getting control flow on DataflowIR. Will be unnecessary when we
380
+ # integrate control flow into abstract values.
381
+ # --------------------------------------------------------------------------- #
382
+
383
+
384
+ class GetControlPredicates (LoopIR_Do ):
385
+ def __init__ (self , proc , stmts ):
386
+ self .proc = proc
387
+ self .stmts = stmts
388
+ self .preds = None
389
+ self .done = False
390
+ self .cur_preds = []
391
+
392
+ for a in self .proc .args :
393
+ if isinstance (a .type , T .Size ):
394
+ size_pred = A .BinOp (
395
+ "<" ,
396
+ A .Const (0 , T .int , null_srcinfo ()),
397
+ A .Var (a .name , T .size , a .srcinfo ),
398
+ T .bool ,
399
+ null_srcinfo (),
400
+ )
401
+ self .cur_preds .append (size_pred )
402
+ self .do_t (a .type )
403
+
404
+ for pred in self .proc .preds :
405
+ self .cur_preds .append (lift_e (pred ))
406
+ self .do_e (pred )
407
+
408
+ self .do_stmts (self .proc .body )
409
+
410
+ def do_s (self , s ):
411
+ if self .done :
412
+ return
413
+
414
+ if s == self .stmts [0 ]:
415
+ self .preds = AAnd (* self .cur_preds )
416
+ self .done = True
417
+
418
+ styp = type (s )
419
+ if styp is LoopIR .If :
420
+ self .cur_preds .append (lift_e (s .cond ))
421
+ self .do_stmts (s .body )
422
+ self .cur_preds .pop ()
423
+
424
+ self .cur_preds .append (A .Not (lift_e (s .cond ), T .int , null_srcinfo ()))
425
+ self .do_stmts (s .orelse )
426
+ self .cur_preds .pop ()
427
+
428
+ elif styp is LoopIR .For :
429
+ a_iter = A .Var (s .iter , T .int , s .srcinfo )
430
+ b1 = A .BinOp ("<=" , lift_e (s .lo ), a_iter , T .bool , null_srcinfo ())
431
+ b2 = A .BinOp ("<" , a_iter , lift_e (s .hi ), T .bool , null_srcinfo ())
432
+ cond = A .BinOp ("and" , b1 , b2 , T .bool , null_srcinfo ())
433
+ self .cur_preds .append (cond )
434
+ self .do_stmts (s .body )
435
+ self .cur_preds .pop ()
436
+
437
+ super ().do_s (s )
438
+
439
+ def result (self ):
440
+ return self .preds .simplify ()
441
+
442
+
379
443
# Produce a set of AExprs which occur as right-hand-sides
380
444
# of config writes.
381
445
def possible_config_writes (stmts ):
@@ -1531,11 +1595,13 @@ def loop_globenv(i, lo_expr, hi_expr, body):
1531
1595
1532
1596
1533
1597
def Check_ReorderStmts (proc , s1 , s2 ):
1534
- datair , stmts = LoopIR_to_DataflowIR (proc , [s1 , s2 ]).result ()
1598
+ # datair, stmts = LoopIR_to_DataflowIR(proc, [s1, s2]).result()
1599
+
1600
+ # print("here in ReorderStmts")
1535
1601
1536
- assert len ( stmts ) == 2
1602
+ assert isinstance ( s1 , LoopIR . stmt ) and isinstance ( s2 , LoopIR . stmt )
1537
1603
1538
- p = GetControlPredicates (datair , stmts ).result ()
1604
+ p = GetControlPredicates (proc , [ s1 , s2 ] ).result ()
1539
1605
1540
1606
slv = SMTSolver (verbose = False )
1541
1607
slv .push ()
@@ -1554,11 +1620,13 @@ def Check_ReorderStmts(proc, s1, s2):
1554
1620
1555
1621
1556
1622
def Check_ReorderLoops (proc , s ):
1557
- datair , stmts = LoopIR_to_DataflowIR (proc , [s ]).result ()
1623
+ # datair, stmts = LoopIR_to_DataflowIR(proc, [s]).result()
1558
1624
1559
- assert len ( stmts ) == 1
1625
+ # print("here in ReorderLoops")
1560
1626
1561
- p = GetControlPredicates (datair , stmts ).result ()
1627
+ assert isinstance (s , LoopIR .For )
1628
+
1629
+ p = GetControlPredicates (proc , [s ]).result ()
1562
1630
1563
1631
slv = SMTSolver (verbose = False )
1564
1632
slv .push ()
@@ -1632,11 +1700,13 @@ def bds(x, lo, hi):
1632
1700
# /\ ( forall i,i'. May(InBound(i,i',e) /\ i < i') => Commutes(a1', a1) )
1633
1701
#
1634
1702
def Check_ParallelizeLoop (proc , s ):
1635
- datair , stmts = LoopIR_to_DataflowIR (proc , [s ]).result ()
1703
+ # datair, stmts = LoopIR_to_DataflowIR(proc, [s]).result()
1704
+
1705
+ # print("Check_ParallelizeLoop")
1636
1706
1637
- assert len ( stmts ) == 1
1707
+ assert isinstance ( s , LoopIR . For )
1638
1708
1639
- p = GetControlPredicates (datair , stmts ).result ()
1709
+ p = GetControlPredicates (proc , [ s ] ).result ()
1640
1710
1641
1711
slv = SMTSolver (verbose = False )
1642
1712
slv .push ()
@@ -1688,9 +1758,11 @@ def bds(x, lo, hi):
1688
1758
#
1689
1759
def Check_FissionLoop (proc , loop , stmts1 , stmts2 , no_loop_var_1 = False ):
1690
1760
1691
- datair , d_loop = LoopIR_to_DataflowIR ( proc , [ loop ]). result ( )
1761
+ # print("Check_FissionLoop" )
1692
1762
1693
- p = GetControlPredicates (datair , d_loop ).result ()
1763
+ # datair, d_loop = LoopIR_to_DataflowIR(proc, [loop]).result()
1764
+
1765
+ p = GetControlPredicates (proc , [loop ]).result ()
1694
1766
1695
1767
slv = SMTSolver (verbose = False )
1696
1768
slv .push ()
@@ -1774,9 +1846,9 @@ def lift_dexpr(e, key=None):
1774
1846
def Check_DeleteConfigWrite (proc , stmts ):
1775
1847
assert len (stmts ) > 0
1776
1848
1777
- ir1 , d_stmts = LoopIR_to_DataflowIR (proc , stmts ).result ()
1778
- p = GetControlPredicates (ir1 , d_stmts ).result ()
1849
+ print ("here in DeleteConfigWrite" )
1779
1850
1851
+ p = GetControlPredicates (proc , stmts ).result ()
1780
1852
slv = SMTSolver (verbose = False )
1781
1853
slv .push ()
1782
1854
slv .assume (AMay (p ))
@@ -1801,6 +1873,7 @@ def Check_DeleteConfigWrite(proc, stmts):
1801
1873
)
1802
1874
1803
1875
# Below are the actual checks
1876
+ ir1 , d_stmts = LoopIR_to_DataflowIR (proc , stmts ).result ()
1804
1877
1805
1878
ScalarPropagation (ir1 )
1806
1879
@@ -1869,6 +1942,8 @@ def Check_ExtendEqv(proc1, proc2, stmts1, stmts2, cfg_mod):
1869
1942
assert len (stmts1 ) == 1
1870
1943
assert len (stmts2 ) == 1
1871
1944
1945
+ print ("here in Check_ExtendEqv" )
1946
+
1872
1947
slv = SMTSolver (verbose = False )
1873
1948
slv .push ()
1874
1949
@@ -1928,16 +2003,18 @@ def make_point(key):
1928
2003
1929
2004
1930
2005
def Check_ExprEqvInContext (proc , expr0 , stmts0 , expr1 , stmts1 = None ):
2006
+
2007
+ # print("Check_ExprEqvInContext")
1931
2008
assert len (stmts0 ) > 0
1932
2009
stmts1 = stmts1 or stmts0
1933
2010
1934
- len_0 = len (stmts0 )
1935
- datair , d_stmts = LoopIR_to_DataflowIR (proc , stmts0 + stmts1 ).result ()
1936
- d_stmts0 = d_stmts [0 :len_0 ]
1937
- d_stmts1 = d_stmts [len_0 :]
2011
+ # len_0 = len(stmts0)
2012
+ # datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts0 + stmts1).result()
2013
+ # d_stmts0 = d_stmts[0:len_0]
2014
+ # d_stmts1 = d_stmts[len_0:]
1938
2015
1939
- p0 = GetControlPredicates (datair , d_stmts0 ).result ()
1940
- p1 = GetControlPredicates (datair , d_stmts1 ).result ()
2016
+ p0 = GetControlPredicates (proc , stmts0 ).result ()
2017
+ p1 = GetControlPredicates (proc , stmts1 ).result ()
1941
2018
1942
2019
slv = SMTSolver (verbose = False )
1943
2020
slv .push ()
@@ -1954,11 +2031,13 @@ def Check_ExprEqvInContext(proc, expr0, stmts0, expr1, stmts1=None):
1954
2031
1955
2032
1956
2033
def Check_BufferReduceOnly (proc , stmts , buf , ndim ):
2034
+
2035
+ print ("Check_BufferReduceOnly" )
1957
2036
assert len (stmts ) > 0
1958
2037
1959
- datair , d_stmts = LoopIR_to_DataflowIR (proc , stmts ).result ()
2038
+ # datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()
1960
2039
1961
- p = GetControlPredicates (datair , d_stmts ).result ()
2040
+ p = GetControlPredicates (proc , stmts ).result ()
1962
2041
1963
2042
slv = SMTSolver (verbose = False )
1964
2043
slv .push ()
@@ -1988,13 +2067,15 @@ def Check_Access_In_Window(proc, access_cursor, w_exprs, block_cursor):
1988
2067
block_cursor is the context in which to interpret the access in.
1989
2068
"""
1990
2069
2070
+ # print("Check_Access_In_Window")
2071
+
1991
2072
access = access_cursor ._node
1992
2073
block = [x ._node for x in block_cursor ]
1993
2074
idxs = access .idx
1994
2075
assert len (idxs ) == len (w_exprs )
1995
2076
1996
- datair , d_stmts = LoopIR_to_DataflowIR (proc , block ).result ()
1997
- p = GetControlPredicates (datair , d_stmts ).result ()
2077
+ # datair, d_stmts = LoopIR_to_DataflowIR(proc, block).result()
2078
+ p = GetControlPredicates (proc , block ).result ()
1998
2079
1999
2080
slv = SMTSolver (verbose = False )
2000
2081
slv .push ()
@@ -2067,9 +2148,10 @@ def Check_Bounds(proc, alloc_stmt, block):
2067
2148
if len (block ) == 0 :
2068
2149
return
2069
2150
2070
- datair , stmts = LoopIR_to_DataflowIR (proc , block ).result ()
2151
+ # print("Check_Bounds")
2152
+ # datair, stmts = LoopIR_to_DataflowIR(proc, block).result()
2071
2153
2072
- p = GetControlPredicates (datair , stmts ).result ()
2154
+ p = GetControlPredicates (proc , block ).result ()
2073
2155
2074
2156
slv = SMTSolver (verbose = False )
2075
2157
slv .push ()
@@ -2105,6 +2187,8 @@ def Check_Bounds(proc, alloc_stmt, block):
2105
2187
2106
2188
2107
2189
def Check_IsDeadAfter (proc , stmts , bufname , ndim ):
2190
+
2191
+ print ("Check_IsDeadAfter" )
2108
2192
assert len (stmts ) > 0
2109
2193
2110
2194
ap = PostEnv (proc , stmts ).get_posteffs ()
@@ -2126,11 +2210,13 @@ def Check_IsDeadAfter(proc, stmts, bufname, ndim):
2126
2210
2127
2211
2128
2212
def Check_IsIdempotent (proc , stmts ):
2213
+
2214
+ print ("Check_IsIdempotent" )
2129
2215
assert len (stmts ) > 0
2130
2216
2131
- datair , d_stmts = LoopIR_to_DataflowIR (proc , stmts ).result ()
2217
+ # datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()
2132
2218
2133
- p = GetControlPredicates (datair , d_stmts ).result ()
2219
+ p = GetControlPredicates (proc , stmts ).result ()
2134
2220
2135
2221
slv = SMTSolver (verbose = False )
2136
2222
slv .push ()
@@ -2144,10 +2230,11 @@ def Check_IsIdempotent(proc, stmts):
2144
2230
2145
2231
2146
2232
def Check_ExprBound (proc , stmts , expr , op , value , exception = True ):
2233
+ print ("Check_ExprBound" )
2147
2234
assert len (stmts ) > 0
2148
2235
2149
- datair , d_stmts = LoopIR_to_DataflowIR (proc , stmts ).result ()
2150
- p = GetControlPredicates (datair , d_stmts ).result ()
2236
+ # datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()
2237
+ p = GetControlPredicates (proc , stmts ).result ()
2151
2238
2152
2239
# TODO: Check_ExprBound does not depend on configuration states so this can be skipped, but more fundamentally running abstract interpretation this many times is simply too slow.
2153
2240
# ScalarPropagation(datair)
@@ -2335,5 +2422,6 @@ def do_s(self, s):
2335
2422
2336
2423
2337
2424
def Check_Aliasing (proc ):
2425
+ print ("Check_Aliasing" )
2338
2426
helper = _Check_Aliasing_Helper (proc )
2339
2427
# that's it
0 commit comments