1
+ # FIXME Add copyright header
2
+
3
+
1
4
import islpy as isl
5
+ from islpy import dim_type
2
6
import pymbolic .primitives as p
3
7
4
8
from dataclasses import dataclass
7
11
8
12
from loopy import LoopKernel
9
13
from loopy .symbolic import WalkMapper
14
+ from loopy .translation_unit import for_each_kernel
15
+ from loopy .typing import ExpressionT
10
16
11
17
@dataclass (frozen = True )
12
18
class HappensAfter :
@@ -32,7 +38,7 @@ def __init__(self, kernel: LoopKernel, var_names: set):
32
38
33
39
super .__init__ ()
34
40
35
- def map_subscript (self , expr : p . expression , inames : frozenset , insn_id : str ):
41
+ def map_subscript (self , expr : ExpressionT , inames : frozenset , insn_id : str ):
36
42
37
43
domain = self .kernel .get_inames_domain (inames )
38
44
@@ -110,7 +116,7 @@ def compute_happens_after(knl: LoopKernel) -> LoopKernel:
110
116
# return the kernel with the new instructions
111
117
return knl .copy (instructions = new_insns )
112
118
113
- def add_lexicographic_happens_after (knl : LoopKernel ) -> None :
119
+ def add_lexicographic_happens_after_orig (knl : LoopKernel ) -> None :
114
120
"""
115
121
TODO properly format this documentation.
116
122
@@ -122,7 +128,7 @@ def add_lexicographic_happens_after(knl: LoopKernel) -> None:
122
128
"""
123
129
124
130
# we want to modify the output dimension and OUT = 3
125
- dim_type = isl .dim_type ( 3 )
131
+ dim_type = isl .dim_type . out
126
132
127
133
# generate an unordered mapping from statement instances to points in the
128
134
# loop domain
@@ -148,3 +154,74 @@ def add_lexicographic_happens_after(knl: LoopKernel) -> None:
148
154
149
155
# determine a lexicographic order on the space the schedules belong to
150
156
157
+
158
+ @for_each_kernel
159
+ def add_lexicographic_happens_after (knl : LoopKernel ) -> LoopKernel :
160
+
161
+ new_insns = []
162
+
163
+ for iafter , insn_after in enumerate (knl .instructions ):
164
+ if iafter == 0 :
165
+ new_insns .append (insn_after )
166
+ else :
167
+ insn_before = knl .instructions [iafter - 1 ]
168
+ shared_inames = insn_after .within_inames & insn_before .within_inames
169
+ unshared_before = insn_before .within_inames
170
+
171
+ domain_before = knl .get_inames_domain (insn_before .within_inames )
172
+ domain_after = knl .get_inames_domain (insn_after .within_inames )
173
+
174
+ happens_before = isl .Map .from_domain_and_range (
175
+ domain_before , domain_after )
176
+ for idim in range (happens_before .dim (dim_type .out )):
177
+ happens_before = happens_before .set_dim_name (
178
+ dim_type .out , idim ,
179
+ happens_before .get_dim_name (dim_type .out , idim ) + "'" )
180
+ n_inames_before = happens_before .dim (dim_type .in_ )
181
+ happens_before_set = happens_before .move_dims (
182
+ dim_type .out , 0 ,
183
+ dim_type .in_ , 0 ,
184
+ n_inames_before ).range ()
185
+
186
+ shared_inames_order_before = [
187
+ domain_before .get_dim_name (dim_type .out , idim )
188
+ for idim in range (domain_before .dim (dim_type .out ))
189
+ if domain_before .get_dim_name (dim_type .out , idim )
190
+ in shared_inames ]
191
+ shared_inames_order_after = [
192
+ domain_after .get_dim_name (dim_type .out , idim )
193
+ for idim in range (domain_after .dim (dim_type .out ))
194
+ if domain_after .get_dim_name (dim_type .out , idim )
195
+ in shared_inames ]
196
+
197
+ assert shared_inames_order_after == shared_inames_order_before
198
+ shared_inames_order = shared_inames_order_after
199
+
200
+ affs = isl .affs_from_space (happens_before_set .space )
201
+
202
+ lex_set = isl .Set .empty (happens_before_set .space )
203
+ for iinnermost , innermost_iname in enumerate (shared_inames_order ):
204
+ innermost_set = affs [innermost_iname ].lt_set (
205
+ affs [innermost_iname + "'" ])
206
+
207
+ for outer_iname in shared_inames_order [:iinnermost ]:
208
+ innermost_set = innermost_set & (
209
+ affs [outer_iname ].eq_set (affs [outer_iname + "'" ]))
210
+
211
+ lex_set = lex_set | innermost_set
212
+
213
+ lex_map = isl .Map .from_range (lex_set ).move_dims (
214
+ dim_type .in_ , 0 ,
215
+ dim_type .out , 0 ,
216
+ n_inames_before )
217
+
218
+ happens_before = happens_before & lex_map
219
+
220
+ pu .db
221
+
222
+ new_insns .append (insn_after )
223
+
224
+ return knl .copy (instructions = new_insns )
225
+
226
+
227
+
0 commit comments