Skip to content

Commit fdf3fad

Browse files
wence-sv2518
authored andcommitted
codegen: New init_with_zero option for packs
Used to handle new Kernel requirement when the Kernel expects output arguments to be zero on entry. Fixes firedrakeproject/firedrake#1768.
1 parent 7b0fc5b commit fdf3fad

File tree

2 files changed

+33
-15
lines changed

2 files changed

+33
-15
lines changed

pyop2/codegen/builder.py

+29-13
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,10 @@ def emit_unpack_instruction(self, *, loop_indices=None):
145145

146146
class GlobalPack(Pack):
147147

148-
def __init__(self, outer, access):
148+
def __init__(self, outer, access, init_with_zero=False):
149149
self.outer = outer
150150
self.access = access
151+
self.init_with_zero = init_with_zero
151152

152153
def kernel_arg(self, loop_indices=None):
153154
pack = self.pack(loop_indices)
@@ -169,11 +170,15 @@ def pack(self, loop_indices=None):
169170
# vectorisation loop transformations privatise these reduction
170171
# variables. The extra memory movement cost is minimal.
171172
loop_indices = self.pick_loop_indices(*loop_indices)
172-
if self.access in {INC, WRITE}:
173+
if self.init_with_zero:
174+
also_zero = {MIN, MAX}
175+
else:
176+
also_zero = set()
177+
if self.access in {INC, WRITE} | also_zero:
173178
val = Zero((), self.outer.dtype)
174179
multiindex = MultiIndex(*(Index(e) for e in shape))
175180
self._pack = Materialise(PackInst(loop_indices), val, multiindex)
176-
elif self.access in {READ, RW, MIN, MAX}:
181+
elif self.access in {READ, RW, MIN, MAX} - also_zero:
177182
multiindex = MultiIndex(*(Index(e) for e in shape))
178183
expr = Indexed(self.outer, multiindex)
179184
self._pack = Materialise(PackInst(loop_indices), expr, multiindex)
@@ -203,13 +208,15 @@ def emit_unpack_instruction(self, *, loop_indices=None):
203208

204209
class DatPack(Pack):
205210
def __init__(self, outer, access, map_=None, interior_horizontal=False,
206-
view_index=None, layer_bounds=None):
211+
view_index=None, layer_bounds=None,
212+
init_with_zero=False):
207213
self.outer = outer
208214
self.map_ = map_
209215
self.access = access
210216
self.interior_horizontal = interior_horizontal
211217
self.view_index = view_index
212218
self.layer_bounds = layer_bounds
219+
self.init_with_zero = init_with_zero
213220

214221
def _mask(self, map_):
215222
"""Override this if the map_ needs a masking condition."""
@@ -245,11 +252,15 @@ def pack(self, loop_indices=None):
245252
if self.view_index is None:
246253
shape = shape + self.outer.shape[1:]
247254

248-
if self.access in {INC, WRITE}:
255+
if self.init_with_zero:
256+
also_zero = {MIN, MAX}
257+
else:
258+
also_zero = set()
259+
if self.access in {INC, WRITE} | also_zero:
249260
val = Zero((), self.outer.dtype)
250261
multiindex = MultiIndex(*(Index(e) for e in shape))
251262
self._pack = Materialise(PackInst(), val, multiindex)
252-
elif self.access in {READ, RW, MIN, MAX}:
263+
elif self.access in {READ, RW, MIN, MAX} - also_zero:
253264
multiindex = MultiIndex(*(Index(e) for e in shape))
254265
expr, mask = self._rvalue(multiindex, loop_indices=loop_indices)
255266
if mask is not None:
@@ -577,8 +588,9 @@ def emit_unpack_instruction(self, *,
577588

578589
class WrapperBuilder(object):
579590

580-
def __init__(self, *, iterset, iteration_region=None, single_cell=False,
591+
def __init__(self, *, kernel, iterset, iteration_region=None, single_cell=False,
581592
pass_layer_to_kernel=False, forward_arg_types=()):
593+
self.kernel = kernel
582594
self.arguments = []
583595
self.argument_accesses = []
584596
self.packed_args = []
@@ -593,6 +605,10 @@ def __init__(self, *, iterset, iteration_region=None, single_cell=False,
593605
self.single_cell = single_cell
594606
self.forward_arguments = tuple(Argument((), fa, pfx="farg") for fa in forward_arg_types)
595607

608+
@property
609+
def requires_zeroed_output_arguments(self):
610+
return self.kernel.requires_zeroed_output_arguments
611+
596612
@property
597613
def subset(self):
598614
return isinstance(self.iterset, Subset)
@@ -605,9 +621,6 @@ def extruded(self):
605621
def constant_layers(self):
606622
return self.extruded and self.iterset.constant_layers
607623

608-
def set_kernel(self, kernel):
609-
self.kernel = kernel
610-
611624
@cached_property
612625
def loop_extents(self):
613626
return (Argument((), IntType, name="start"),
@@ -722,7 +735,8 @@ def add_argument(self, arg):
722735
shape = (None, *a.data.shape[1:])
723736
argument = Argument(shape, a.data.dtype, pfx="mdat")
724737
packs.append(a.data.pack(argument, arg.access, self.map_(a.map, unroll=a.unroll_map),
725-
interior_horizontal=interior_horizontal))
738+
interior_horizontal=interior_horizontal,
739+
init_with_zero=self.requires_zeroed_output_arguments))
726740
self.arguments.append(argument)
727741
pack = MixedDatPack(packs, arg.access, arg.dtype, interior_horizontal=interior_horizontal)
728742
self.packed_args.append(pack)
@@ -740,15 +754,17 @@ def add_argument(self, arg):
740754
pfx="dat")
741755
pack = arg.data.pack(argument, arg.access, self.map_(arg.map, unroll=arg.unroll_map),
742756
interior_horizontal=interior_horizontal,
743-
view_index=view_index)
757+
view_index=view_index,
758+
init_with_zero=self.requires_zeroed_output_arguments)
744759
self.arguments.append(argument)
745760
self.packed_args.append(pack)
746761
self.argument_accesses.append(arg.access)
747762
elif arg._is_global:
748763
argument = Argument(arg.data.dim,
749764
arg.data.dtype,
750765
pfx="glob")
751-
pack = GlobalPack(argument, arg.access)
766+
pack = GlobalPack(argument, arg.access,
767+
init_with_zero=self.requires_zeroed_output_arguments)
752768
self.arguments.append(argument)
753769
self.packed_args.append(pack)
754770
self.argument_accesses.append(arg.access)

pyop2/sequential.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,12 @@ def code_to_compile(self):
148148
from pyop2.codegen.builder import WrapperBuilder
149149
from pyop2.codegen.rep2loopy import generate
150150

151-
builder = WrapperBuilder(iterset=self._iterset, iteration_region=self._iteration_region, pass_layer_to_kernel=self._pass_layer_arg)
151+
builder = WrapperBuilder(kernel=self._kernel,
152+
iterset=self._iterset,
153+
iteration_region=self._iteration_region,
154+
pass_layer_to_kernel=self._pass_layer_arg)
152155
for arg in self._args:
153156
builder.add_argument(arg)
154-
builder.set_kernel(self._kernel)
155157

156158
wrapper = generate(builder)
157159
if self._iterset._extruded:

0 commit comments

Comments
 (0)