Skip to content

Commit 8823473

Browse files
committed
add tests
1 parent fd200ab commit 8823473

File tree

2 files changed

+63
-2
lines changed

2 files changed

+63
-2
lines changed

src/exo/dataflow.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,8 +382,10 @@ def fix_stmt(self, pre_env, stmt: DataflowIR.stmt, post_env):
382382
# if reducing, then expand to x = x + rhs
383383
rhs_e = stmt.rhs
384384
if isinstance(stmt, DataflowIR.Reduce):
385-
read_buf = DataflowIR.Read(stmt.name, stmt.idx)
386-
rhs_e = DataflowIR.BinOp("+", read_buf, rhs_e)
385+
read_buf = DataflowIR.Read(
386+
stmt.name, stmt.idx, rhs_e.type, stmt.srcinfo
387+
)
388+
rhs_e = DataflowIR.BinOp("+", read_buf, rhs_e, rhs_e.type, stmt.srcinfo)
387389
# now we can handle both cases uniformly
388390
rval = self.fix_expr(pre_env, rhs_e)
389391
# need to be careful for buffers (no overwrite guarantee)

tests/test_dataflow.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,62 @@ def foo(x: R[3], y: R[3], z: R):
8383
print()
8484
print(foo.dataflow())
8585
print()
86+
87+
88+
# TODO: Currently add_unsafe_guard lacks analysis, but we should be able to analyze this
89+
def test_sliding_window():
90+
@proc
91+
def foo(n: size, m: size, dst: i8[n + m], src: i8[n + m]):
92+
for i in seq(0, n):
93+
for j in seq(0, m):
94+
dst[i + j] = src[i + j]
95+
96+
foo = add_unsafe_guard(foo, "dst[_] = src[_]", "i == 0 or j == m - 1")
97+
print()
98+
print(foo.dataflow())
99+
print()
100+
101+
102+
# TODO: fission should be able to handle this
103+
def test_fission_fail():
104+
@proc
105+
def foo(n: size, dst: i8[n + 1], src: i8[n + 1]):
106+
for i in seq(0, n):
107+
dst[i] = src[i]
108+
dst[i + 1] = src[i + 1]
109+
110+
with pytest.raises(SchedulingError, match="Cannot fission"):
111+
foo = fission(foo, foo.find("dst[i] = _").after())
112+
print(foo)
113+
114+
115+
# TODO: This is unsafe, lift_alloc should give an error
116+
def test_lift_alloc_unsafe():
117+
@proc
118+
def foo():
119+
for i in seq(0, 10):
120+
a: i8[11] @ DRAM
121+
a[i] = 1.0
122+
a[i + 1] += 1.0
123+
124+
foo = lift_alloc(foo, "a : _")
125+
print()
126+
print(foo.dataflow())
127+
print()
128+
129+
130+
# TODO: We are not supporting this AFAIK but should keep this example in mind
131+
def test_reduc():
132+
@proc
133+
def foo(n: size, a: f32, c: f32):
134+
tmp: f32[n]
135+
for i in seq(0, n):
136+
for j in seq(0, 4):
137+
tmp[i] = a
138+
a = tmp[i] + 1.0
139+
for i in seq(0, n):
140+
c += tmp[i] # some use of tmp
141+
142+
print()
143+
print(foo.dataflow())
144+
print()

0 commit comments

Comments
 (0)