@@ -150,22 +150,55 @@ def __init__(self, outer, access):
150
150
self .access = access
151
151
152
152
def kernel_arg(self, loop_indices=None):
    """Return the pack wrapped in ``Indexed``, for use as the kernel argument.

    ``loop_indices`` is forwarded unchanged to ``self.pack``. The result
    indexes that pack with one ``Index(e)`` per entry ``e`` of the pack's
    ``shape``.
    """
    pack = self.pack(loop_indices)
    # The indices are handed over as a generator expression, so Indexed
    # is the one that consumes them.
    return Indexed(pack, (Index(e) for e in pack.shape))
154
155
155
156
def emit_pack_instruction(self, *, loop_indices=None):
    """Return the pack instructions for this argument: always none.

    ``loop_indices`` is accepted as a keyword-only argument but is not
    used; the result is an empty tuple whatever its value.
    """
    return ()
158
+
159
def pack(self, loop_indices=None):
    """Return the object that stands in for ``self.outer``.

    For READ access this is ``self.outer`` itself. For every other
    supported access a ``Materialise`` temporary is built, stored on
    ``self._pack`` and returned; later calls return the stored object
    without looking at ``loop_indices`` again.

    ``loop_indices`` is unpacked into ``self.pick_loop_indices`` whenever
    a temporary has to be built, so it must be an iterable on that path
    (the ``None`` default is only usable for READ access or once the
    pack has been cached).

    Raises ``ValueError`` when ``self.access`` is none of READ, WRITE,
    INC, RW, MIN or MAX.
    """
    try:
        return self._pack
    except AttributeError:
        # Not built yet; fall through and construct it.
        pass

    if self.access is READ:
        # No packing required
        return self.outer

    # We don't need to pack for memory layout, however packing
    # globals that are written is required such that subsequent
    # vectorisation loop transformations privatise these reduction
    # variables. The extra memory movement cost is minimal.
    loop_indices = self.pick_loop_indices(*loop_indices)
    shape = self.outer.shape
    if self.access in {INC, WRITE}:
        # The temporary is initialised from a Zero of the outer dtype.
        val = Zero((), self.outer.dtype)
        multiindex = MultiIndex(*(Index(e) for e in shape))
        self._pack = Materialise(PackInst(loop_indices), val, multiindex)
    elif self.access in {RW, MIN, MAX}:
        # The temporary is initialised from the indexed outer value.
        # READ is not listed here: it has already returned above.
        multiindex = MultiIndex(*(Index(e) for e in shape))
        expr = Indexed(self.outer, multiindex)
        self._pack = Materialise(PackInst(loop_indices), expr, multiindex)
    else:
        raise ValueError("Don't know how to initialise pack for '%s' access" % self.access)
    return self._pack
166
183
167
184
def emit_unpack_instruction(self, *, loop_indices=None):
    """Yield the instruction that writes the pack back into ``self.outer``.

    This is a generator. It yields nothing when ``self.pack`` returns
    ``None`` or when access is READ. Otherwise it yields a single
    ``Accumulate`` whose target is the indexed ``self.outer``:

    * for INC, MIN and MAX the stored value is ``Sum``, ``Min`` or
      ``Max`` (respectively) of the target and the indexed pack;
    * for any other access the stored value is the indexed pack itself.

    ``loop_indices`` is forwarded to ``self.pack`` and, only when an
    instruction is actually emitted, unpacked into
    ``self.pick_loop_indices``; it must be an iterable in that case.
    """
    pack = self.pack(loop_indices)
    # Decide whether there is anything to unpack *before* touching
    # loop_indices: unpacking the None default would raise TypeError
    # even though no instruction needs to be emitted.
    if pack is None or self.access is READ:
        return

    loop_indices = self.pick_loop_indices(*loop_indices)
    multiindex = tuple(Index(e) for e in pack.shape)
    rvalue = Indexed(self.outer, multiindex)
    packed = Indexed(pack, multiindex)
    # Reductions combine with what is already stored in the outer
    # object; every other access simply overwrites it.
    combine = {INC: Sum, MIN: Min, MAX: Max}.get(self.access)
    value = packed if combine is None else combine(rvalue, packed)
    yield Accumulate(UnpackInst(loop_indices), rvalue, value)
169
202
170
203
171
204
class DatPack (Pack ):
0 commit comments