@@ -51,7 +51,10 @@ def __init__(self, map_, interior_horizontal, layer_bounds,
51
51
shape = (None , ) + map_ .shape [1 :]
52
52
values = Argument (shape , dtype = map_ .dtype , pfx = "map" )
53
53
if offset is not None :
54
- offset = NamedLiteral (offset , name = values .name + "_offset" )
54
+ if len (set (map_ .offset )) == 1 :
55
+ offset = Literal (offset [0 ], casting = True )
56
+ else :
57
+ offset = NamedLiteral (offset , name = values .name + "_offset" )
55
58
56
59
self .values = values
57
60
self .offset = offset
@@ -68,21 +71,33 @@ def indexed(self, multiindex, layer=None):
68
71
n , i , f = multiindex
69
72
if layer is not None and self .offset is not None :
70
73
# For extruded mesh, prefetch the indirections for each map, so that they don't
71
- # need to be recomputed. Different f values need to be treated separately.
74
+ # need to be recomputed.
75
+ # First prefetch the base map (not dependent on layers)
76
+ base_key = None
77
+ if base_key not in self .prefetch :
78
+ j = Index ()
79
+ base = Indexed (self .values , (n , j ))
80
+ self .prefetch [base_key ] = Materialise (PackInst (), base , MultiIndex (j ))
81
+
82
+ base = self .prefetch [base_key ]
83
+
84
+ # Now prefetch the extruded part of the map (inside the layer loop).
85
+ # This is necessary so loopy DTRT for MatSetValues
86
+ # Different f values need to be treated separately.
72
87
key = f .extent
73
88
if key is None :
74
89
key = 1
75
90
if key not in self .prefetch :
76
91
bottom_layer , _ = self .layer_bounds
77
- offset_extent , = self .offset .shape
78
- j = Index (offset_extent )
79
- base = Indexed (self .values , (n , j ))
80
- if f .extent :
81
- k = Index (f .extent )
82
- else :
83
- k = Index (1 )
92
+ k = Index (f .extent if f .extent is not None else 1 )
84
93
offset = Sum (Sum (layer , Product (Literal (numpy .int32 (- 1 )), bottom_layer )), k )
85
- offset = Product (offset , Indexed (self .offset , (j ,)))
94
+ j = Index ()
95
+ # Inline map offsets where all entries are identical.
96
+ if self .offset .shape == ():
97
+ offset = Product (offset , self .offset )
98
+ else :
99
+ offset = Product (offset , Indexed (self .offset , (j ,)))
100
+ base = Indexed (base , (j , ))
86
101
self .prefetch [key ] = Materialise (PackInst (), Sum (base , offset ), MultiIndex (k , j ))
87
102
88
103
return Indexed (self .prefetch [key ], (f , i )), (f , i )
@@ -130,38 +145,78 @@ def emit_unpack_instruction(self, *, loop_indices=None):
130
145
131
146
class GlobalPack(Pack):
    """Pack for a global argument.

    Globals accessed as READ are handed to the kernel unchanged.  For
    every other access a temporary is materialised, the kernel operates
    on the temporary, and :meth:`emit_unpack_instruction` writes or
    combines it back into the global.
    """

    def __init__(self, outer, access, init_with_zero=False):
        """
        :arg outer: expression for the global as seen by the wrapper.
        :arg access: access descriptor; READ, WRITE, RW, INC, MIN and
            MAX are the values handled by this class.
        :arg init_with_zero: if true, temporaries for MIN and MAX access
            are initialised with zero instead of with the current
            contents of ``outer``.
        """
        self.outer = outer
        self.access = access
        self.init_with_zero = init_with_zero

    def kernel_arg(self, loop_indices=None):
        """Return the pack indexed over its whole shape."""
        pack = self.pack(loop_indices)
        # Use a tuple (as emit_unpack_instruction does) rather than a
        # one-shot generator for the multiindex.
        return Indexed(pack, tuple(Index(e) for e in pack.shape))

    def emit_pack_instruction(self, *, loop_indices=None):
        """Emit nothing: initialisation is carried by the Materialise
        node that :meth:`pack` builds."""
        return ()

    def pack(self, loop_indices=None):
        """Return (building and caching on first use) the packed global.

        :arg loop_indices: forwarded to ``self.pick_loop_indices``; must
            be an iterable for any access other than READ.
        :returns: ``self.outer`` for READ access, otherwise a cached
            ``Materialise`` temporary.
        :raises ValueError: if the access descriptor is not recognised.
        """
        if hasattr(self, "_pack"):
            return self._pack

        if self.access is READ:
            # No packing required
            return self.outer

        # We don't need to pack for memory layout, however packing
        # globals that are written is required such that subsequent
        # vectorisation loop transformations privatise these reduction
        # variables. The extra memory movement cost is minimal.
        loop_indices = self.pick_loop_indices(*loop_indices)

        # Accesses whose temporary starts from zero ...
        zeroed = {INC, WRITE}
        if self.init_with_zero:
            zeroed |= {MIN, MAX}
        # ... and those whose temporary starts as a copy of the global.
        copied = {RW, MIN, MAX} - zeroed

        if self.access not in zeroed | copied:
            raise ValueError("Don't know how to initialise pack for '%s' access" % self.access)

        multiindex = MultiIndex(*(Index(e) for e in self.outer.shape))
        if self.access in zeroed:
            initial = Zero((), self.outer.dtype)
        else:
            initial = Indexed(self.outer, multiindex)
        self._pack = Materialise(PackInst(loop_indices), initial, multiindex)
        return self._pack

    def emit_unpack_instruction(self, *, loop_indices=None):
        """Yield the instruction that moves the temporary back into the global.

        Nothing is yielded for READ access.  INC, MIN and MAX combine the
        temporary with the current global value using Sum, Min and Max
        respectively; any other access overwrites the global.
        """
        # Decide READ before touching loop_indices, so a read-only global
        # never tries to unpack a missing (None) loop_indices.
        if self.access is READ:
            return

        pack = self.pack(loop_indices)
        if pack is None:
            # Defensive: kept for subclasses whose pack() yields nothing.
            return

        loop_indices = self.pick_loop_indices(*loop_indices)
        multiindex = tuple(Index(e) for e in pack.shape)
        rvalue = Indexed(self.outer, multiindex)
        update = Indexed(pack, multiindex)
        if self.access in {INC, MIN, MAX}:
            combine = {INC: Sum,
                       MIN: Min,
                       MAX: Max}[self.access]
            update = combine(rvalue, update)
        yield Accumulate(UnpackInst(loop_indices), rvalue, update)
154
207
155
208
156
209
class DatPack (Pack ):
157
210
def __init__ (self , outer , access , map_ = None , interior_horizontal = False ,
158
- view_index = None , layer_bounds = None ):
211
+ view_index = None , layer_bounds = None ,
212
+ init_with_zero = False ):
159
213
self .outer = outer
160
214
self .map_ = map_
161
215
self .access = access
162
216
self .interior_horizontal = interior_horizontal
163
217
self .view_index = view_index
164
218
self .layer_bounds = layer_bounds
219
+ self .init_with_zero = init_with_zero
165
220
166
221
def _mask (self , map_ ):
167
222
"""Override this if the map_ needs a masking condition."""
@@ -197,11 +252,15 @@ def pack(self, loop_indices=None):
197
252
if self .view_index is None :
198
253
shape = shape + self .outer .shape [1 :]
199
254
200
- if self .access in {INC , WRITE }:
255
+ if self .init_with_zero :
256
+ also_zero = {MIN , MAX }
257
+ else :
258
+ also_zero = set ()
259
+ if self .access in {INC , WRITE } | also_zero :
201
260
val = Zero ((), self .outer .dtype )
202
261
multiindex = MultiIndex (* (Index (e ) for e in shape ))
203
262
self ._pack = Materialise (PackInst (), val , multiindex )
204
- elif self .access in {READ , RW , MIN , MAX }:
263
+ elif self .access in {READ , RW , MIN , MAX } - also_zero :
205
264
multiindex = MultiIndex (* (Index (e ) for e in shape ))
206
265
expr , mask = self ._rvalue (multiindex , loop_indices = loop_indices )
207
266
if mask is not None :
@@ -529,8 +588,9 @@ def emit_unpack_instruction(self, *,
529
588
530
589
class WrapperBuilder (object ):
531
590
532
- def __init__ (self , * , iterset , iteration_region = None , single_cell = False ,
591
+ def __init__ (self , * , kernel , iterset , iteration_region = None , single_cell = False ,
533
592
pass_layer_to_kernel = False , forward_arg_types = ()):
593
+ self .kernel = kernel
534
594
self .arguments = []
535
595
self .argument_accesses = []
536
596
self .packed_args = []
@@ -545,6 +605,10 @@ def __init__(self, *, iterset, iteration_region=None, single_cell=False,
545
605
self .single_cell = single_cell
546
606
self .forward_arguments = tuple (Argument ((), fa , pfx = "farg" ) for fa in forward_arg_types )
547
607
608
+ @property
609
+ def requires_zeroed_output_arguments (self ):
610
+ return self .kernel .requires_zeroed_output_arguments
611
+
548
612
@property
549
613
def subset (self ):
550
614
return isinstance (self .iterset , Subset )
@@ -557,9 +621,6 @@ def extruded(self):
557
621
def constant_layers (self ):
558
622
return self .extruded and self .iterset .constant_layers
559
623
560
- def set_kernel (self , kernel ):
561
- self .kernel = kernel
562
-
563
624
@cached_property
564
625
def loop_extents (self ):
565
626
return (Argument ((), IntType , name = "start" ),
@@ -674,7 +735,8 @@ def add_argument(self, arg):
674
735
shape = (None , * a .data .shape [1 :])
675
736
argument = Argument (shape , a .data .dtype , pfx = "mdat" )
676
737
packs .append (a .data .pack (argument , arg .access , self .map_ (a .map , unroll = a .unroll_map ),
677
- interior_horizontal = interior_horizontal ))
738
+ interior_horizontal = interior_horizontal ,
739
+ init_with_zero = self .requires_zeroed_output_arguments ))
678
740
self .arguments .append (argument )
679
741
pack = MixedDatPack (packs , arg .access , arg .dtype , interior_horizontal = interior_horizontal )
680
742
self .packed_args .append (pack )
@@ -692,15 +754,17 @@ def add_argument(self, arg):
692
754
pfx = "dat" )
693
755
pack = arg .data .pack (argument , arg .access , self .map_ (arg .map , unroll = arg .unroll_map ),
694
756
interior_horizontal = interior_horizontal ,
695
- view_index = view_index )
757
+ view_index = view_index ,
758
+ init_with_zero = self .requires_zeroed_output_arguments )
696
759
self .arguments .append (argument )
697
760
self .packed_args .append (pack )
698
761
self .argument_accesses .append (arg .access )
699
762
elif arg ._is_global :
700
763
argument = Argument (arg .data .dim ,
701
764
arg .data .dtype ,
702
765
pfx = "glob" )
703
- pack = GlobalPack (argument , arg .access )
766
+ pack = GlobalPack (argument , arg .access ,
767
+ init_with_zero = self .requires_zeroed_output_arguments )
704
768
self .arguments .append (argument )
705
769
self .packed_args .append (pack )
706
770
self .argument_accesses .append (arg .access )
0 commit comments