24
24
| Var( sym name )
25
25
| Const( object val, type type )
26
26
| BinOp( binop op, mexpr lhs, mexpr rhs )
27
- | Array( mexpr arg, sym idx ) -- !!This is not enough b/c we'd like to be able to affine-index-access arrays, but is a literal implementation of Fluid update
28
- path = ( aexpr constraints, mexpr tgt )
27
+ | Array( sym name, avar *dims )
28
+ path = ( aexpr nc, aexpr sc, mexpr tgt ) -- perform weak update for now
29
29
env = ( avar *dims, path* paths ) -- This can handle index access uniformly!
30
30
}
31
31
""" ,
@@ -109,6 +109,9 @@ def validateAbsEnv(obj):
109
109
# Top Level Call to Dataflow analysis
110
110
# --------------------------------------------------------------------------- #
111
111
112
+ aexpr_false = A .Const (False , T .bool , null_srcinfo ())
113
+ aexpr_true = A .Const (True , T .bool , null_srcinfo ())
114
+
112
115
113
116
class LoopIR_to_DataflowIR :
114
117
def __init__ (self , proc ):
@@ -257,14 +260,18 @@ def __str__(self):
257
260
if isinstance (self , D .BinOp ):
258
261
return str (self .lhs ) + str (e .op ) + str (self .rhs )
259
262
if isinstance (self , D .Array ):
260
- return str (self .arg ) + "[" + str (self .idx ) + "]"
263
+ dim_str = "["
264
+ for d in self .dims :
265
+ dim_str += str (d )
266
+ dim_str += "]"
267
+ return str (self .name ) + dim_str
261
268
262
269
assert False , "bad case"
263
270
264
271
265
272
@extclass (AbstractDomains .path )
266
273
def __str__ (self ):
267
- return "(" + str (self .constraints ) + ", " + str (self .tgt ) + ")"
274
+ return "(" + str (self .nc ) + ", " + str (self .sc ) + ") : " + str ( self . tgt )
268
275
269
276
270
277
@extclass (AbstractDomains .env )
@@ -291,13 +298,14 @@ def update(env: D.env, rval: list[D.path]):
291
298
merge_paths = []
292
299
for pre_path in env .paths :
293
300
for rval_path in rval :
294
- pre_cons = pre_path .constraints .simplify ()
295
- rval_cons = rval_path .constraints .simplify ()
301
+ pre_cons = pre_path .nc .simplify ()
302
+ rval_cons = rval_path .nc .simplify ()
296
303
297
304
if isinstance (pre_path .tgt , D .Unk ):
298
305
pre_paths .remove (pre_path )
299
306
elif pre_cons == rval_cons :
300
- merge_paths .append (D .path (rval_cons , rval_path .tgt ))
307
+ # TODO: Handle strong update
308
+ merge_paths .append (D .path (rval_cons , rval_path .sc , rval_path .tgt ))
301
309
pre_paths .remove (pre_path )
302
310
rval_paths .remove (rval_path )
303
311
@@ -308,21 +316,39 @@ def bind_cons(cons: A.expr, rval: list[D.path]):
308
316
new_paths = []
309
317
310
318
for path in rval :
311
- new_cons = A .BinOp ("and" , path .constraints , cons , T .bool , null_srcinfo ())
312
- new_path = D .path (new_cons .simplify (), path .tgt )
319
+ new_nc = A .BinOp ("and" , path .nc , cons , T .bool , null_srcinfo ())
320
+ new_path = D .path (new_nc .simplify (), path . sc , path .tgt )
313
321
new_paths .append (new_path )
314
322
315
323
return new_paths
316
324
317
325
326
+ def propagate_cons (cons : A .expr , env : D .env ):
327
+ return D .env (env .dims , bind_cons (cons , env .paths ))
328
+
329
+
330
+ def ir_to_aexpr (e : DataflowIR .expr ):
331
+ if isinstance (e , DataflowIR .Const ):
332
+ ae = A .Const (e .val , e .type , null_srcinfo ())
333
+ elif isinstance (e , DataflowIR .Read ):
334
+ ae = A .Var (e .name , e .type , null_srcinfo ())
335
+ elif isinstance (e , Sym ):
336
+ ae = A .Var (e , T .index , null_srcinfo ())
337
+ else :
338
+ assert False , f"got { e } of type { type (e )} "
339
+
340
+ return ae
341
+
342
+
318
343
class AbstractInterpretation (ABC ):
319
344
def __init__ (self , proc : DataflowIR .proc ):
320
345
self .proc = proc
321
346
322
347
# setup initial values
323
348
init_env = self .proc .body .ctxts [0 ]
324
349
for a in proc .args :
325
- init_env [a .name ] = self .abs_init_val (a .name , a .type )
350
+ if a .type .is_numeric ():
351
+ init_env [a .name ] = self .abs_init_val (a .name , a .type )
326
352
327
353
# We probably ought to somehow use precondition assertions
328
354
# TODO: leave it for now
@@ -356,14 +382,7 @@ def fix_stmt(self, pre_env, stmt: DataflowIR.stmt, post_env):
356
382
# Handle constraints
357
383
cons = A .Const (True , T .bool , null_srcinfo ())
358
384
for b , e in zip (pre_env [stmt .name ].dims , stmt .idx ):
359
- # TODO!!: Replace this with a general pass to convert DataflowIR to Aexpr
360
- if isinstance (e , DataflowIR .Const ):
361
- e = A .Const (e .val , e .type , null_srcinfo ())
362
- elif isinstance (e , DataflowIR .Read ):
363
- e = A .Var (e .name , e .type , null_srcinfo ())
364
- else :
365
- assert False , "???"
366
- eq = A .BinOp ("==" , b , e , T .bool , null_srcinfo ())
385
+ eq = A .BinOp ("==" , b , ir_to_aexpr (e ), T .bool , null_srcinfo ())
367
386
cons = A .BinOp ("and" , cons , eq , T .bool , null_srcinfo ())
368
387
369
388
rval = bind_cons (cons , rval )
@@ -421,28 +440,31 @@ def fix_stmt(self, pre_env, stmt: DataflowIR.stmt, post_env):
421
440
elif isinstance (stmt , DataflowIR .For ):
422
441
# TODO: Add support for loop-condition analysis in some way?
423
442
424
- # set up the loop body for fixed-point iteration
425
443
pre_body = stmt .body .ctxts [0 ]
444
+ iter_cons = self .abs_iter_val (
445
+ ir_to_aexpr (stmt .iter ), ir_to_aexpr (stmt .lo ), ir_to_aexpr (stmt .hi )
446
+ )
426
447
for nm , val in pre_env .items ():
427
- pre_body [nm ] = val
428
- # initialize the loop iteration variable
429
- lo = self .fix_expr (pre_env , stmt .lo )
430
- hi = self .fix_expr (pre_env , stmt .hi )
431
- pre_body [stmt .iter ] = self .abs_iter_val (lo , hi )
448
+ pre_body [nm ] = propagate_cons (iter_cons , val )
432
449
450
+ # Commenting out the following. We don't need to run a fixed-point
451
+
452
+ # set up the loop body for fixed-point iteration
433
453
# run this loop until we reach a fixed-point
434
- at_fixed_point = False
435
- while not at_fixed_point :
436
- # propagate in the loop
437
- self .fix_block (stmt .body )
438
- at_fixed_point = True
439
- # copy the post-values for the loop back around to
440
- # the pre-values, by joining them together
441
- for nm , prev_val in pre_body .items ():
442
- next_val = stmt .body .ctxts [- 1 ][nm ]
443
- val = self .abs_join (prev_val , next_val )
444
- at_fixed_point = at_fixed_point and prev_val == val
445
- pre_body [nm ] = val
454
+ # at_fixed_point = False
455
+ # while not at_fixed_point:
456
+ # propagate in the loop
457
+ # self.fix_block(stmt.body)
458
+ # at_fixed_point = True
459
+ # copy the post-values for the loop back around to
460
+ # the pre-values, by joining them together
461
+ # for nm, prev_val in pre_body.items():
462
+ # next_val = stmt.body.ctxts[-1][nm]
463
+ # val = self.abs_join(prev_val, next_val)
464
+ # at_fixed_point = at_fixed_point and prev_val == val
465
+ # pre_body[nm] = val
466
+
467
+ self .fix_block (stmt .body )
446
468
447
469
# determine the post-env as join of pre-env and loop results
448
470
for nm , pre_val in pre_env .items ():
@@ -496,7 +518,7 @@ def abs_alloc_val(self, name, typ):
496
518
"""Define initial value of an allocation"""
497
519
498
520
@abstractmethod
499
- def abs_iter_val (self , lo , hi ):
521
+ def abs_iter_val (self , name , lo , hi ):
500
522
"""Define value of an iteration variable"""
501
523
502
524
@abstractmethod
@@ -525,11 +547,15 @@ def abs_builtin(self, builtin, args):
525
547
526
548
527
549
def make_empty_path (me : D .mexpr ) -> D .path :
528
- return [D .path (A . Const ( True , T . bool , null_srcinfo ()) , me )]
550
+ return [D .path (aexpr_true , aexpr_false , me )]
529
551
530
552
531
553
def make_unk () -> D .path :
532
- return [D .path (A .Const (True , T .bool , null_srcinfo ()), D .Unk ())]
554
+ return [D .path (aexpr_true , aexpr_false , D .Unk ())]
555
+
556
+
557
+ def make_unk_array (buf_name : Sym , dims : list ) -> D .path :
558
+ return [D .path (aexpr_true , aexpr_false , D .Array (buf_name , dims ))]
533
559
534
560
535
561
class ConstantPropagation (AbstractInterpretation ):
@@ -540,7 +566,7 @@ def abs_init_val(self, name, typ):
540
566
dims .append (
541
567
A .Var (Sym (name .name () + "_" + str (i )), T .index , null_srcinfo ())
542
568
)
543
- return D .env (dims , make_unk ( ))
569
+ return D .env (dims , make_unk_array ( name , dims ))
544
570
else :
545
571
return D .env ([], make_unk ())
546
572
@@ -551,12 +577,14 @@ def abs_alloc_val(self, name, typ):
551
577
dims .append (
552
578
A .Var (Sym (name .name () + "_" + str (i )), T .index , null_srcinfo ())
553
579
)
554
- return D .env (dims , make_unk ( ))
580
+ return D .env (dims , make_unk_array ( name , dims ))
555
581
else :
556
582
return D .env ([], make_unk ())
557
583
558
- def abs_iter_val (self , lo , hi ):
559
- return D .env ([], make_unk ())
584
+ def abs_iter_val (self , name , lo , hi ):
585
+ lo_cons = A .BinOp ("<=" , lo , name , T .index , null_srcinfo ())
586
+ hi_cons = A .BinOp ("<" , name , hi , T .index , null_srcinfo ())
587
+ return AAnd (lo_cons , hi_cons )
560
588
561
589
def abs_stride_expr (self , name , dim ):
562
590
assert False , "unimplemented"
0 commit comments