@@ -145,9 +145,10 @@ def emit_unpack_instruction(self, *, loop_indices=None):
145
145
146
146
class GlobalPack (Pack ):
147
147
148
- def __init__ (self , outer , access ):
148
+ def __init__ (self , outer , access , init_with_zero = False ):
149
149
self .outer = outer
150
150
self .access = access
151
+ self .init_with_zero = init_with_zero
151
152
152
153
def kernel_arg (self , loop_indices = None ):
153
154
pack = self .pack (loop_indices )
@@ -169,11 +170,15 @@ def pack(self, loop_indices=None):
169
170
# vectorisation loop transformations privatise these reduction
170
171
# variables. The extra memory movement cost is minimal.
171
172
loop_indices = self .pick_loop_indices (* loop_indices )
172
- if self .access in {INC , WRITE }:
173
+ if self .init_with_zero :
174
+ also_zero = {MIN , MAX }
175
+ else :
176
+ also_zero = set ()
177
+ if self .access in {INC , WRITE } | also_zero :
173
178
val = Zero ((), self .outer .dtype )
174
179
multiindex = MultiIndex (* (Index (e ) for e in shape ))
175
180
self ._pack = Materialise (PackInst (loop_indices ), val , multiindex )
176
- elif self .access in {READ , RW , MIN , MAX }:
181
+ elif self .access in {READ , RW , MIN , MAX } - also_zero :
177
182
multiindex = MultiIndex (* (Index (e ) for e in shape ))
178
183
expr = Indexed (self .outer , multiindex )
179
184
self ._pack = Materialise (PackInst (loop_indices ), expr , multiindex )
@@ -203,13 +208,15 @@ def emit_unpack_instruction(self, *, loop_indices=None):
203
208
204
209
class DatPack (Pack ):
205
210
def __init__ (self , outer , access , map_ = None , interior_horizontal = False ,
206
- view_index = None , layer_bounds = None ):
211
+ view_index = None , layer_bounds = None ,
212
+ init_with_zero = False ):
207
213
self .outer = outer
208
214
self .map_ = map_
209
215
self .access = access
210
216
self .interior_horizontal = interior_horizontal
211
217
self .view_index = view_index
212
218
self .layer_bounds = layer_bounds
219
+ self .init_with_zero = init_with_zero
213
220
214
221
def _mask (self , map_ ):
215
222
"""Override this if the map_ needs a masking condition."""
@@ -245,11 +252,15 @@ def pack(self, loop_indices=None):
245
252
if self .view_index is None :
246
253
shape = shape + self .outer .shape [1 :]
247
254
248
- if self .access in {INC , WRITE }:
255
+ if self .init_with_zero :
256
+ also_zero = {MIN , MAX }
257
+ else :
258
+ also_zero = set ()
259
+ if self .access in {INC , WRITE } | also_zero :
249
260
val = Zero ((), self .outer .dtype )
250
261
multiindex = MultiIndex (* (Index (e ) for e in shape ))
251
262
self ._pack = Materialise (PackInst (), val , multiindex )
252
- elif self .access in {READ , RW , MIN , MAX }:
263
+ elif self .access in {READ , RW , MIN , MAX } - also_zero :
253
264
multiindex = MultiIndex (* (Index (e ) for e in shape ))
254
265
expr , mask = self ._rvalue (multiindex , loop_indices = loop_indices )
255
266
if mask is not None :
@@ -577,8 +588,9 @@ def emit_unpack_instruction(self, *,
577
588
578
589
class WrapperBuilder (object ):
579
590
580
- def __init__ (self , * , iterset , iteration_region = None , single_cell = False ,
591
+ def __init__ (self , * , kernel , iterset , iteration_region = None , single_cell = False ,
581
592
pass_layer_to_kernel = False , forward_arg_types = ()):
593
+ self .kernel = kernel
582
594
self .arguments = []
583
595
self .argument_accesses = []
584
596
self .packed_args = []
@@ -593,6 +605,10 @@ def __init__(self, *, iterset, iteration_region=None, single_cell=False,
593
605
self .single_cell = single_cell
594
606
self .forward_arguments = tuple (Argument ((), fa , pfx = "farg" ) for fa in forward_arg_types )
595
607
608
+ @property
609
+ def requires_zeroed_output_arguments (self ):
610
+ return self .kernel .requires_zeroed_output_arguments
611
+
596
612
@property
597
613
def subset (self ):
598
614
return isinstance (self .iterset , Subset )
@@ -605,9 +621,6 @@ def extruded(self):
605
621
def constant_layers (self ):
606
622
return self .extruded and self .iterset .constant_layers
607
623
608
- def set_kernel (self , kernel ):
609
- self .kernel = kernel
610
-
611
624
@cached_property
612
625
def loop_extents (self ):
613
626
return (Argument ((), IntType , name = "start" ),
@@ -722,7 +735,8 @@ def add_argument(self, arg):
722
735
shape = (None , * a .data .shape [1 :])
723
736
argument = Argument (shape , a .data .dtype , pfx = "mdat" )
724
737
packs .append (a .data .pack (argument , arg .access , self .map_ (a .map , unroll = a .unroll_map ),
725
- interior_horizontal = interior_horizontal ))
738
+ interior_horizontal = interior_horizontal ,
739
+ init_with_zero = self .requires_zeroed_output_arguments ))
726
740
self .arguments .append (argument )
727
741
pack = MixedDatPack (packs , arg .access , arg .dtype , interior_horizontal = interior_horizontal )
728
742
self .packed_args .append (pack )
@@ -740,15 +754,17 @@ def add_argument(self, arg):
740
754
pfx = "dat" )
741
755
pack = arg .data .pack (argument , arg .access , self .map_ (arg .map , unroll = arg .unroll_map ),
742
756
interior_horizontal = interior_horizontal ,
743
- view_index = view_index )
757
+ view_index = view_index ,
758
+ init_with_zero = self .requires_zeroed_output_arguments )
744
759
self .arguments .append (argument )
745
760
self .packed_args .append (pack )
746
761
self .argument_accesses .append (arg .access )
747
762
elif arg ._is_global :
748
763
argument = Argument (arg .data .dim ,
749
764
arg .data .dtype ,
750
765
pfx = "glob" )
751
- pack = GlobalPack (argument , arg .access )
766
+ pack = GlobalPack (argument , arg .access ,
767
+ init_with_zero = self .requires_zeroed_output_arguments )
752
768
self .arguments .append (argument )
753
769
self .packed_args .append (pack )
754
770
self .argument_accesses .append (arg .access )
0 commit comments