Skip to content

Commit 959d90c

Browse files
committed
update
1 parent 7a152a5 commit 959d90c

File tree

2 files changed

+89
-43
lines changed

2 files changed

+89
-43
lines changed

src/exo/dataflow.py

Lines changed: 71 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
| Var( sym name )
2525
| Const( object val, type type )
2626
| BinOp( binop op, mexpr lhs, mexpr rhs )
27-
| Array( mexpr arg, sym idx ) -- !!This is not enough b/c we'd like to be able to affine-index-access arrays, but is a literal implementation of Fluid update
28-
path = ( aexpr constraints, mexpr tgt )
27+
| Array( sym name, avar *dims )
28+
path = ( aexpr nc, aexpr sc, mexpr tgt ) -- perform weak update for now
2929
env = ( avar *dims, path* paths ) -- This can handle index access uniformly!
3030
}
3131
""",
@@ -109,6 +109,9 @@ def validateAbsEnv(obj):
109109
# Top Level Call to Dataflow analysis
110110
# --------------------------------------------------------------------------- #
111111

112+
aexpr_false = A.Const(False, T.bool, null_srcinfo())
113+
aexpr_true = A.Const(True, T.bool, null_srcinfo())
114+
112115

113116
class LoopIR_to_DataflowIR:
114117
def __init__(self, proc):
@@ -257,14 +260,18 @@ def __str__(self):
257260
if isinstance(self, D.BinOp):
258261
return str(self.lhs) + str(e.op) + str(self.rhs)
259262
if isinstance(self, D.Array):
260-
return str(self.arg) + "[" + str(self.idx) + "]"
263+
dim_str = "["
264+
for d in self.dims:
265+
dim_str += str(d)
266+
dim_str += "]"
267+
return str(self.name) + dim_str
261268

262269
assert False, "bad case"
263270

264271

265272
@extclass(AbstractDomains.path)
266273
def __str__(self):
267-
return "(" + str(self.constraints) + ", " + str(self.tgt) + ")"
274+
return "(" + str(self.nc) + ", " + str(self.sc) + ") : " + str(self.tgt)
268275

269276

270277
@extclass(AbstractDomains.env)
@@ -291,13 +298,14 @@ def update(env: D.env, rval: list[D.path]):
291298
merge_paths = []
292299
for pre_path in env.paths:
293300
for rval_path in rval:
294-
pre_cons = pre_path.constraints.simplify()
295-
rval_cons = rval_path.constraints.simplify()
301+
pre_cons = pre_path.nc.simplify()
302+
rval_cons = rval_path.nc.simplify()
296303

297304
if isinstance(pre_path.tgt, D.Unk):
298305
pre_paths.remove(pre_path)
299306
elif pre_cons == rval_cons:
300-
merge_paths.append(D.path(rval_cons, rval_path.tgt))
307+
# TODO: Handle strong update
308+
merge_paths.append(D.path(rval_cons, rval_path.sc, rval_path.tgt))
301309
pre_paths.remove(pre_path)
302310
rval_paths.remove(rval_path)
303311

@@ -308,21 +316,39 @@ def bind_cons(cons: A.expr, rval: list[D.path]):
308316
new_paths = []
309317

310318
for path in rval:
311-
new_cons = A.BinOp("and", path.constraints, cons, T.bool, null_srcinfo())
312-
new_path = D.path(new_cons.simplify(), path.tgt)
319+
new_nc = A.BinOp("and", path.nc, cons, T.bool, null_srcinfo())
320+
new_path = D.path(new_nc.simplify(), path.sc, path.tgt)
313321
new_paths.append(new_path)
314322

315323
return new_paths
316324

317325

326+
def propagate_cons(cons: A.expr, env: D.env):
327+
return D.env(env.dims, bind_cons(cons, env.paths))
328+
329+
330+
def ir_to_aexpr(e: DataflowIR.expr):
331+
if isinstance(e, DataflowIR.Const):
332+
ae = A.Const(e.val, e.type, null_srcinfo())
333+
elif isinstance(e, DataflowIR.Read):
334+
ae = A.Var(e.name, e.type, null_srcinfo())
335+
elif isinstance(e, Sym):
336+
ae = A.Var(e, T.index, null_srcinfo())
337+
else:
338+
assert False, f"got {e} of type {type(e)}"
339+
340+
return ae
341+
342+
318343
class AbstractInterpretation(ABC):
319344
def __init__(self, proc: DataflowIR.proc):
320345
self.proc = proc
321346

322347
# setup initial values
323348
init_env = self.proc.body.ctxts[0]
324349
for a in proc.args:
325-
init_env[a.name] = self.abs_init_val(a.name, a.type)
350+
if a.type.is_numeric():
351+
init_env[a.name] = self.abs_init_val(a.name, a.type)
326352

327353
# We probably ought to somehow use precondition assertions
328354
# TODO: leave it for now
@@ -356,14 +382,7 @@ def fix_stmt(self, pre_env, stmt: DataflowIR.stmt, post_env):
356382
# Handle constraints
357383
cons = A.Const(True, T.bool, null_srcinfo())
358384
for b, e in zip(pre_env[stmt.name].dims, stmt.idx):
359-
# TODO!!: Replace this with a general pass to convert DataflowIR to Aexpr
360-
if isinstance(e, DataflowIR.Const):
361-
e = A.Const(e.val, e.type, null_srcinfo())
362-
elif isinstance(e, DataflowIR.Read):
363-
e = A.Var(e.name, e.type, null_srcinfo())
364-
else:
365-
assert False, "???"
366-
eq = A.BinOp("==", b, e, T.bool, null_srcinfo())
385+
eq = A.BinOp("==", b, ir_to_aexpr(e), T.bool, null_srcinfo())
367386
cons = A.BinOp("and", cons, eq, T.bool, null_srcinfo())
368387

369388
rval = bind_cons(cons, rval)
@@ -421,28 +440,31 @@ def fix_stmt(self, pre_env, stmt: DataflowIR.stmt, post_env):
421440
elif isinstance(stmt, DataflowIR.For):
422441
# TODO: Add support for loop-condition analysis in some way?
423442

424-
# set up the loop body for fixed-point iteration
425443
pre_body = stmt.body.ctxts[0]
444+
iter_cons = self.abs_iter_val(
445+
ir_to_aexpr(stmt.iter), ir_to_aexpr(stmt.lo), ir_to_aexpr(stmt.hi)
446+
)
426447
for nm, val in pre_env.items():
427-
pre_body[nm] = val
428-
# initialize the loop iteration variable
429-
lo = self.fix_expr(pre_env, stmt.lo)
430-
hi = self.fix_expr(pre_env, stmt.hi)
431-
pre_body[stmt.iter] = self.abs_iter_val(lo, hi)
448+
pre_body[nm] = propagate_cons(iter_cons, val)
432449

450+
# Commenting out the following. We don't need to run a fixed-point
451+
452+
# set up the loop body for fixed-point iteration
433453
# run this loop until we reach a fixed-point
434-
at_fixed_point = False
435-
while not at_fixed_point:
436-
# propagate in the loop
437-
self.fix_block(stmt.body)
438-
at_fixed_point = True
439-
# copy the post-values for the loop back around to
440-
# the pre-values, by joining them together
441-
for nm, prev_val in pre_body.items():
442-
next_val = stmt.body.ctxts[-1][nm]
443-
val = self.abs_join(prev_val, next_val)
444-
at_fixed_point = at_fixed_point and prev_val == val
445-
pre_body[nm] = val
454+
# at_fixed_point = False
455+
# while not at_fixed_point:
456+
# propagate in the loop
457+
# self.fix_block(stmt.body)
458+
# at_fixed_point = True
459+
# copy the post-values for the loop back around to
460+
# the pre-values, by joining them together
461+
# for nm, prev_val in pre_body.items():
462+
# next_val = stmt.body.ctxts[-1][nm]
463+
# val = self.abs_join(prev_val, next_val)
464+
# at_fixed_point = at_fixed_point and prev_val == val
465+
# pre_body[nm] = val
466+
467+
self.fix_block(stmt.body)
446468

447469
# determine the post-env as join of pre-env and loop results
448470
for nm, pre_val in pre_env.items():
@@ -496,7 +518,7 @@ def abs_alloc_val(self, name, typ):
496518
"""Define initial value of an allocation"""
497519

498520
@abstractmethod
499-
def abs_iter_val(self, lo, hi):
521+
def abs_iter_val(self, name, lo, hi):
500522
"""Define value of an iteration variable"""
501523

502524
@abstractmethod
@@ -525,11 +547,15 @@ def abs_builtin(self, builtin, args):
525547

526548

527549
def make_empty_path(me: D.mexpr) -> D.path:
528-
return [D.path(A.Const(True, T.bool, null_srcinfo()), me)]
550+
return [D.path(aexpr_true, aexpr_false, me)]
529551

530552

531553
def make_unk() -> D.path:
532-
return [D.path(A.Const(True, T.bool, null_srcinfo()), D.Unk())]
554+
return [D.path(aexpr_true, aexpr_false, D.Unk())]
555+
556+
557+
def make_unk_array(buf_name: Sym, dims: list) -> D.path:
558+
return [D.path(aexpr_true, aexpr_false, D.Array(buf_name, dims))]
533559

534560

535561
class ConstantPropagation(AbstractInterpretation):
@@ -540,7 +566,7 @@ def abs_init_val(self, name, typ):
540566
dims.append(
541567
A.Var(Sym(name.name() + "_" + str(i)), T.index, null_srcinfo())
542568
)
543-
return D.env(dims, make_unk())
569+
return D.env(dims, make_unk_array(name, dims))
544570
else:
545571
return D.env([], make_unk())
546572

@@ -551,12 +577,14 @@ def abs_alloc_val(self, name, typ):
551577
dims.append(
552578
A.Var(Sym(name.name() + "_" + str(i)), T.index, null_srcinfo())
553579
)
554-
return D.env(dims, make_unk())
580+
return D.env(dims, make_unk_array(name, dims))
555581
else:
556582
return D.env([], make_unk())
557583

558-
def abs_iter_val(self, lo, hi):
559-
return D.env([], make_unk())
584+
def abs_iter_val(self, name, lo, hi):
585+
lo_cons = A.BinOp("<=", lo, name, T.index, null_srcinfo())
586+
hi_cons = A.BinOp("<", name, hi, T.index, null_srcinfo())
587+
return AAnd(lo_cons, hi_cons)
560588

561589
def abs_stride_expr(self, name, dim):
562590
assert False, "unimplemented"

tests/test_dataflow.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,21 @@ def foo(n: size, a: f32, c: f32):
182182
print()
183183
print(foo.dataflow())
184184
print()
185+
186+
187+
def test_absval_init():
188+
@proc
189+
def foo(n: size, dst: f32[n]):
190+
for i in seq(0, n):
191+
dst[i] = 0.0
192+
193+
print()
194+
print(foo.dataflow())
195+
196+
@proc
197+
def foo(n: size, dst: f32[n], src: f32[n]):
198+
for i in seq(0, n):
199+
dst[i] = src[i]
200+
201+
print()
202+
print(foo.dataflow())

0 commit comments

Comments
 (0)