Skip to content

Commit cededcd

Browse files
committed
Sketch out add_lexicographic_happens_after
1 parent 93b4126 commit cededcd

File tree

1 file changed

+80
-3
lines changed

1 file changed

+80
-3
lines changed

loopy/kernel/dependency.py

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
# FIXME Add copyright header
2+
3+
14
import islpy as isl
5+
from islpy import dim_type
26
import pymbolic.primitives as p
37

48
from dataclasses import dataclass
@@ -7,6 +11,8 @@
711

812
from loopy import LoopKernel
913
from loopy.symbolic import WalkMapper
14+
from loopy.translation_unit import for_each_kernel
15+
from loopy.typing import ExpressionT
1016

1117
@dataclass(frozen=True)
1218
class HappensAfter:
@@ -32,7 +38,7 @@ def __init__(self, kernel: LoopKernel, var_names: set):
3238

3339
super.__init__()
3440

35-
def map_subscript(self, expr: p.expression, inames: frozenset, insn_id: str):
41+
def map_subscript(self, expr: ExpressionT, inames: frozenset, insn_id: str):
3642

3743
domain = self.kernel.get_inames_domain(inames)
3844

@@ -110,7 +116,7 @@ def compute_happens_after(knl: LoopKernel) -> LoopKernel:
110116
# return the kernel with the new instructions
111117
return knl.copy(instructions=new_insns)
112118

113-
def add_lexicographic_happens_after(knl: LoopKernel) -> None:
119+
def add_lexicographic_happens_after_orig(knl: LoopKernel) -> None:
114120
"""
115121
TODO properly format this documentation.
116122
@@ -122,7 +128,7 @@ def add_lexicographic_happens_after(knl: LoopKernel) -> None:
122128
"""
123129

124130
# we want to modify the output dimension and OUT = 3
125-
dim_type = isl.dim_type(3)
131+
dim_type = isl.dim_type.out
126132

127133
# generate an unordered mapping from statement instances to points in the
128134
# loop domain
@@ -148,3 +154,74 @@ def add_lexicographic_happens_after(knl: LoopKernel) -> None:
148154

149155
# determine a lexicographic order on the space the schedules belong to
150156

157+
158+
@for_each_kernel
159+
def add_lexicographic_happens_after(knl: LoopKernel) -> LoopKernel:
160+
161+
new_insns = []
162+
163+
for iafter, insn_after in enumerate(knl.instructions):
164+
if iafter == 0:
165+
new_insns.append(insn_after)
166+
else:
167+
insn_before = knl.instructions[iafter - 1]
168+
shared_inames = insn_after.within_inames & insn_before.within_inames
169+
unshared_before = insn_before.within_inames
170+
171+
domain_before = knl.get_inames_domain(insn_before.within_inames)
172+
domain_after = knl.get_inames_domain(insn_after.within_inames)
173+
174+
happens_before = isl.Map.from_domain_and_range(
175+
domain_before, domain_after)
176+
for idim in range(happens_before.dim(dim_type.out)):
177+
happens_before = happens_before.set_dim_name(
178+
dim_type.out, idim,
179+
happens_before.get_dim_name(dim_type.out, idim) + "'")
180+
n_inames_before = happens_before.dim(dim_type.in_)
181+
happens_before_set = happens_before.move_dims(
182+
dim_type.out, 0,
183+
dim_type.in_, 0,
184+
n_inames_before).range()
185+
186+
shared_inames_order_before = [
187+
domain_before.get_dim_name(dim_type.out, idim)
188+
for idim in range(domain_before.dim(dim_type.out))
189+
if domain_before.get_dim_name(dim_type.out, idim)
190+
in shared_inames]
191+
shared_inames_order_after = [
192+
domain_after.get_dim_name(dim_type.out, idim)
193+
for idim in range(domain_after.dim(dim_type.out))
194+
if domain_after.get_dim_name(dim_type.out, idim)
195+
in shared_inames]
196+
197+
assert shared_inames_order_after == shared_inames_order_before
198+
shared_inames_order = shared_inames_order_after
199+
200+
affs = isl.affs_from_space(happens_before_set.space)
201+
202+
lex_set = isl.Set.empty(happens_before_set.space)
203+
for iinnermost, innermost_iname in enumerate(shared_inames_order):
204+
innermost_set = affs[innermost_iname].lt_set(
205+
affs[innermost_iname+"'"])
206+
207+
for outer_iname in shared_inames_order[:iinnermost]:
208+
innermost_set = innermost_set & (
209+
affs[outer_iname].eq_set(affs[outer_iname + "'"]))
210+
211+
lex_set = lex_set | innermost_set
212+
213+
lex_map = isl.Map.from_range(lex_set).move_dims(
214+
dim_type.in_, 0,
215+
dim_type.out, 0,
216+
n_inames_before)
217+
218+
happens_before = happens_before & lex_map
219+
220+
pu.db
221+
222+
new_insns.append(insn_after)
223+
224+
return knl.copy(instructions=new_insns)
225+
226+
227+

0 commit comments

Comments
 (0)