Skip to content

Commit 32c565b

Browse files
wence-sv2518
authored andcommitted
codegen: Create packs for Global args
This is necessary so that vectorising over the outer loop correctly privatises the accumulation variables.
1 parent 72d3ff0 commit 32c565b

File tree

1 file changed

+42
-9
lines changed

1 file changed

+42
-9
lines changed

pyop2/codegen/builder.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -150,22 +150,55 @@ def __init__(self, outer, access):
150150
self.access = access
151151

152152
def kernel_arg(self, loop_indices=None):
153-
return Indexed(self.outer, (Index(e) for e in self.outer.shape))
153+
pack = self.pack(loop_indices)
154+
return Indexed(pack, (Index(e) for e in pack.shape))
154155

155156
def emit_pack_instruction(self, *, loop_indices=None):
157+
return ()
158+
159+
def pack(self, loop_indices=None):
160+
if hasattr(self, "_pack"):
161+
return self._pack
162+
156163
shape = self.outer.shape
157-
if self.access is WRITE:
158-
zero = Zero((), self.outer.dtype)
164+
if self.access is READ:
165+
# No packing required
166+
return self.outer
167+
# We don't need to pack for memory layout, however packing
168+
# globals that are written is required such that subsequent
169+
# vectorisation loop transformations privatise these reduction
170+
# variables. The extra memory movement cost is minimal.
171+
loop_indices = self.pick_loop_indices(*loop_indices)
172+
if self.access in {INC, WRITE}:
173+
val = Zero((), self.outer.dtype)
174+
multiindex = MultiIndex(*(Index(e) for e in shape))
175+
self._pack = Materialise(PackInst(loop_indices), val, multiindex)
176+
elif self.access in {READ, RW, MIN, MAX}:
159177
multiindex = MultiIndex(*(Index(e) for e in shape))
160-
yield Accumulate(PackInst(), Indexed(self.outer, multiindex), zero)
178+
expr = Indexed(self.outer, multiindex)
179+
self._pack = Materialise(PackInst(loop_indices), expr, multiindex)
161180
else:
162-
return ()
163-
164-
def pack(self, loop_indices=None):
165-
return None
181+
raise ValueError("Don't know how to initialise pack for '%s' access" % self.access)
182+
return self._pack
166183

167184
def emit_unpack_instruction(self, *, loop_indices=None):
168-
return ()
185+
pack = self.pack(loop_indices)
186+
loop_indices = self.pick_loop_indices(*loop_indices)
187+
if pack is None:
188+
return ()
189+
elif self.access is READ:
190+
return ()
191+
elif self.access in {INC, MIN, MAX}:
192+
op = {INC: Sum,
193+
MIN: Min,
194+
MAX: Max}[self.access]
195+
multiindex = tuple(Index(e) for e in pack.shape)
196+
rvalue = Indexed(self.outer, multiindex)
197+
yield Accumulate(UnpackInst(loop_indices), rvalue, op(rvalue, Indexed(pack, multiindex)))
198+
else:
199+
multiindex = tuple(Index(e) for e in pack.shape)
200+
rvalue = Indexed(self.outer, multiindex)
201+
yield Accumulate(UnpackInst(loop_indices), rvalue, Indexed(pack, multiindex))
169202

170203

171204
class DatPack(Pack):

0 commit comments

Comments
 (0)