Skip to content

Commit e7a638b

Browse files
committed
Merge remote-tracking branch 'remotes/origin/wence/feature/codegen-updates' into vectorisation
2 parents c639342 + 11d558f commit e7a638b

File tree

5 files changed

+119
-49
lines changed

5 files changed

+119
-49
lines changed

pyop2/base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3287,6 +3287,8 @@ class Kernel(Cached):
32873287
empty)
32883288
:param ldargs: A list of arguments to pass to the linker when
32893289
compiling this Kernel.
3290+
:param requires_zeroed_output_arguments: Does this kernel require the
3291+
output arguments to be zeroed on entry when called? (default no)
32903292
:param cpp: Is the kernel actually C++ rather than C? If yes,
32913293
then compile with the C++ compiler (kernel is wrapped in
32923294
extern C for linkage reasons).
@@ -3309,7 +3311,7 @@ class Kernel(Cached):
33093311
@classmethod
33103312
@validate_type(('name', str, NameTypeError))
33113313
def _cache_key(cls, code, name, opts={}, include_dirs=[], headers=[],
3312-
user_code="", ldargs=None, cpp=False):
3314+
user_code="", ldargs=None, cpp=False, requires_zeroed_output_arguments=False):
33133315
# Both code and name are relevant since there might be multiple kernels
33143316
# extracting different functions from the same code
33153317
# Also include the PyOP2 version, since the Kernel class might change
@@ -3323,15 +3325,15 @@ def _cache_key(cls, code, name, opts={}, include_dirs=[], headers=[],
33233325
code.update_persistent_hash(key_hash, LoopyKeyBuilder())
33243326
code = key_hash.hexdigest()
33253327
hashee = (str(code) + name + str(sorted(opts.items())) + str(include_dirs)
3326-
+ str(headers) + version + str(ldargs) + str(cpp))
3328+
+ str(headers) + version + str(ldargs) + str(cpp) + str(requires_zeroed_output_arguments))
33273329
return md5(hashee.encode()).hexdigest()
33283330

33293331
@cached_property
33303332
def _wrapper_cache_key_(self):
33313333
return (self._key, )
33323334

33333335
def __init__(self, code, name, opts={}, include_dirs=[], headers=[],
3334-
user_code="", ldargs=None, cpp=False):
3336+
user_code="", ldargs=None, cpp=False, requires_zeroed_output_arguments=False):
33353337
# Protect against re-initialization when retrieved from cache
33363338
if self._initialized:
33373339
return
@@ -3346,6 +3348,7 @@ def __init__(self, code, name, opts={}, include_dirs=[], headers=[],
33463348
assert isinstance(code, (str, Node, loopy.Program, loopy.LoopKernel))
33473349
self._code = code
33483350
self._initialized = True
3351+
self.requires_zeroed_output_arguments = requires_zeroed_output_arguments
33493352

33503353
@property
33513354
def name(self):

pyop2/codegen/builder.py

Lines changed: 94 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,10 @@ def __init__(self, map_, interior_horizontal, layer_bounds,
5151
shape = (None, ) + map_.shape[1:]
5252
values = Argument(shape, dtype=map_.dtype, pfx="map")
5353
if offset is not None:
54-
offset = NamedLiteral(offset, name=values.name + "_offset")
54+
if len(set(map_.offset)) == 1:
55+
offset = Literal(offset[0], casting=True)
56+
else:
57+
offset = NamedLiteral(offset, name=values.name + "_offset")
5558

5659
self.values = values
5760
self.offset = offset
@@ -68,21 +71,33 @@ def indexed(self, multiindex, layer=None):
6871
n, i, f = multiindex
6972
if layer is not None and self.offset is not None:
7073
# For extruded mesh, prefetch the indirections for each map, so that they don't
71-
# need to be recomputed. Different f values need to be treated separately.
74+
# need to be recomputed.
75+
# First prefetch the base map (not dependent on layers)
76+
base_key = None
77+
if base_key not in self.prefetch:
78+
j = Index()
79+
base = Indexed(self.values, (n, j))
80+
self.prefetch[base_key] = Materialise(PackInst(), base, MultiIndex(j))
81+
82+
base = self.prefetch[base_key]
83+
84+
# Now prefetch the extruded part of the map (inside the layer loop).
85+
# This is necessary so that loopy does the right thing for MatSetValues
86+
# Different f values need to be treated separately.
7287
key = f.extent
7388
if key is None:
7489
key = 1
7590
if key not in self.prefetch:
7691
bottom_layer, _ = self.layer_bounds
77-
offset_extent, = self.offset.shape
78-
j = Index(offset_extent)
79-
base = Indexed(self.values, (n, j))
80-
if f.extent:
81-
k = Index(f.extent)
82-
else:
83-
k = Index(1)
92+
k = Index(f.extent if f.extent is not None else 1)
8493
offset = Sum(Sum(layer, Product(Literal(numpy.int32(-1)), bottom_layer)), k)
85-
offset = Product(offset, Indexed(self.offset, (j,)))
94+
j = Index()
95+
# Inline map offsets where all entries are identical.
96+
if self.offset.shape == ():
97+
offset = Product(offset, self.offset)
98+
else:
99+
offset = Product(offset, Indexed(self.offset, (j,)))
100+
base = Indexed(base, (j, ))
86101
self.prefetch[key] = Materialise(PackInst(), Sum(base, offset), MultiIndex(k, j))
87102

88103
return Indexed(self.prefetch[key], (f, i)), (f, i)
@@ -130,38 +145,78 @@ def emit_unpack_instruction(self, *, loop_indices=None):
130145

131146
class GlobalPack(Pack):
132147

133-
def __init__(self, outer, access):
148+
def __init__(self, outer, access, init_with_zero=False):
134149
self.outer = outer
135150
self.access = access
151+
self.init_with_zero = init_with_zero
136152

137153
def kernel_arg(self, loop_indices=None):
138-
return Indexed(self.outer, (Index(e) for e in self.outer.shape))
154+
pack = self.pack(loop_indices)
155+
return Indexed(pack, (Index(e) for e in pack.shape))
139156

140157
def emit_pack_instruction(self, *, loop_indices=None):
158+
return ()
159+
160+
def pack(self, loop_indices=None):
161+
if hasattr(self, "_pack"):
162+
return self._pack
163+
141164
shape = self.outer.shape
142-
if self.access is WRITE:
143-
zero = Zero((), self.outer.dtype)
165+
if self.access is READ:
166+
# No packing required
167+
return self.outer
168+
# We don't need to pack for memory layout, however packing
169+
# globals that are written is required such that subsequent
170+
# vectorisation loop transformations privatise these reduction
171+
# variables. The extra memory movement cost is minimal.
172+
loop_indices = self.pick_loop_indices(*loop_indices)
173+
if self.init_with_zero:
174+
also_zero = {MIN, MAX}
175+
else:
176+
also_zero = set()
177+
if self.access in {INC, WRITE} | also_zero:
178+
val = Zero((), self.outer.dtype)
179+
multiindex = MultiIndex(*(Index(e) for e in shape))
180+
self._pack = Materialise(PackInst(loop_indices), val, multiindex)
181+
elif self.access in {READ, RW, MIN, MAX} - also_zero:
144182
multiindex = MultiIndex(*(Index(e) for e in shape))
145-
yield Accumulate(PackInst(), Indexed(self.outer, multiindex), zero)
183+
expr = Indexed(self.outer, multiindex)
184+
self._pack = Materialise(PackInst(loop_indices), expr, multiindex)
146185
else:
147-
return ()
148-
149-
def pack(self, loop_indices=None):
150-
return None
186+
raise ValueError("Don't know how to initialise pack for '%s' access" % self.access)
187+
return self._pack
151188

152189
def emit_unpack_instruction(self, *, loop_indices=None):
153-
return ()
190+
pack = self.pack(loop_indices)
191+
loop_indices = self.pick_loop_indices(*loop_indices)
192+
if pack is None:
193+
return ()
194+
elif self.access is READ:
195+
return ()
196+
elif self.access in {INC, MIN, MAX}:
197+
op = {INC: Sum,
198+
MIN: Min,
199+
MAX: Max}[self.access]
200+
multiindex = tuple(Index(e) for e in pack.shape)
201+
rvalue = Indexed(self.outer, multiindex)
202+
yield Accumulate(UnpackInst(loop_indices), rvalue, op(rvalue, Indexed(pack, multiindex)))
203+
else:
204+
multiindex = tuple(Index(e) for e in pack.shape)
205+
rvalue = Indexed(self.outer, multiindex)
206+
yield Accumulate(UnpackInst(loop_indices), rvalue, Indexed(pack, multiindex))
154207

155208

156209
class DatPack(Pack):
157210
def __init__(self, outer, access, map_=None, interior_horizontal=False,
158-
view_index=None, layer_bounds=None):
211+
view_index=None, layer_bounds=None,
212+
init_with_zero=False):
159213
self.outer = outer
160214
self.map_ = map_
161215
self.access = access
162216
self.interior_horizontal = interior_horizontal
163217
self.view_index = view_index
164218
self.layer_bounds = layer_bounds
219+
self.init_with_zero = init_with_zero
165220

166221
def _mask(self, map_):
167222
"""Override this if the map_ needs a masking condition."""
@@ -197,11 +252,15 @@ def pack(self, loop_indices=None):
197252
if self.view_index is None:
198253
shape = shape + self.outer.shape[1:]
199254

200-
if self.access in {INC, WRITE}:
255+
if self.init_with_zero:
256+
also_zero = {MIN, MAX}
257+
else:
258+
also_zero = set()
259+
if self.access in {INC, WRITE} | also_zero:
201260
val = Zero((), self.outer.dtype)
202261
multiindex = MultiIndex(*(Index(e) for e in shape))
203262
self._pack = Materialise(PackInst(), val, multiindex)
204-
elif self.access in {READ, RW, MIN, MAX}:
263+
elif self.access in {READ, RW, MIN, MAX} - also_zero:
205264
multiindex = MultiIndex(*(Index(e) for e in shape))
206265
expr, mask = self._rvalue(multiindex, loop_indices=loop_indices)
207266
if mask is not None:
@@ -529,8 +588,9 @@ def emit_unpack_instruction(self, *,
529588

530589
class WrapperBuilder(object):
531590

532-
def __init__(self, *, iterset, iteration_region=None, single_cell=False,
591+
def __init__(self, *, kernel, iterset, iteration_region=None, single_cell=False,
533592
pass_layer_to_kernel=False, forward_arg_types=()):
593+
self.kernel = kernel
534594
self.arguments = []
535595
self.argument_accesses = []
536596
self.packed_args = []
@@ -545,6 +605,10 @@ def __init__(self, *, iterset, iteration_region=None, single_cell=False,
545605
self.single_cell = single_cell
546606
self.forward_arguments = tuple(Argument((), fa, pfx="farg") for fa in forward_arg_types)
547607

608+
@property
609+
def requires_zeroed_output_arguments(self):
610+
return self.kernel.requires_zeroed_output_arguments
611+
548612
@property
549613
def subset(self):
550614
return isinstance(self.iterset, Subset)
@@ -557,9 +621,6 @@ def extruded(self):
557621
def constant_layers(self):
558622
return self.extruded and self.iterset.constant_layers
559623

560-
def set_kernel(self, kernel):
561-
self.kernel = kernel
562-
563624
@cached_property
564625
def loop_extents(self):
565626
return (Argument((), IntType, name="start"),
@@ -674,7 +735,8 @@ def add_argument(self, arg):
674735
shape = (None, *a.data.shape[1:])
675736
argument = Argument(shape, a.data.dtype, pfx="mdat")
676737
packs.append(a.data.pack(argument, arg.access, self.map_(a.map, unroll=a.unroll_map),
677-
interior_horizontal=interior_horizontal))
738+
interior_horizontal=interior_horizontal,
739+
init_with_zero=self.requires_zeroed_output_arguments))
678740
self.arguments.append(argument)
679741
pack = MixedDatPack(packs, arg.access, arg.dtype, interior_horizontal=interior_horizontal)
680742
self.packed_args.append(pack)
@@ -692,15 +754,17 @@ def add_argument(self, arg):
692754
pfx="dat")
693755
pack = arg.data.pack(argument, arg.access, self.map_(arg.map, unroll=arg.unroll_map),
694756
interior_horizontal=interior_horizontal,
695-
view_index=view_index)
757+
view_index=view_index,
758+
init_with_zero=self.requires_zeroed_output_arguments)
696759
self.arguments.append(argument)
697760
self.packed_args.append(pack)
698761
self.argument_accesses.append(arg.access)
699762
elif arg._is_global:
700763
argument = Argument(arg.data.dim,
701764
arg.data.dtype,
702765
pfx="glob")
703-
pack = GlobalPack(argument, arg.access)
766+
pack = GlobalPack(argument, arg.access,
767+
init_with_zero=self.requires_zeroed_output_arguments)
704768
self.arguments.append(argument)
705769
self.packed_args.append(pack)
706770
self.argument_accesses.append(arg.access)

pyop2/codegen/rep2loopy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def loop_nesting(instructions, deps, outer_inames, kernel_name):
231231
if isinstance(insn.children[1], (Zero, Literal)):
232232
nesting[insn] = outer_inames
233233
else:
234-
nesting[insn] = runtime_indices([insn])
234+
nesting[insn] = runtime_indices([insn]) | runtime_indices(insn.label.within_inames)
235235
else:
236236
assert isinstance(insn, FunctionCall)
237237
if insn.name in (petsc_functions | {kernel_name}):
@@ -468,9 +468,9 @@ def generate(builder, wrapper_name=None):
468468
from coffee.base import Node
469469

470470
if isinstance(kernel._code, loopy.LoopKernel):
471+
from loopy.transform.callable import _match_caller_callee_argument_dimension_
471472
knl = kernel._code
472473
wrapper = loopy.register_callable_kernel(wrapper, knl)
473-
from loopy.transform.callable import _match_caller_callee_argument_dimension_
474474
wrapper = _match_caller_callee_argument_dimension_(wrapper, knl.name)
475475
else:
476476
# kernel is a string, add it to preamble

pyop2/codegen/representation.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010

1111
class InstructionLabel(object):
12-
pass
12+
def __init__(self, within_inames=()):
13+
self.within_inames = tuple(w for w in within_inames if isinstance(w, Node))
1314

1415

1516
class PackInst(InstructionLabel):
@@ -99,6 +100,8 @@ def set_extent(self, value):
99100
elif self.extent != value:
100101
raise ValueError("Inconsistent index extents")
101102

103+
dtype = numpy.int32
104+
102105

103106
class FixedIndex(Terminal, Scalar):
104107
__slots__ = ("value", )
@@ -108,7 +111,9 @@ class FixedIndex(Terminal, Scalar):
108111

109112
def __init__(self, value):
110113
assert isinstance(value, numbers.Integral)
111-
self.value = int(value)
114+
self.value = numpy.int32(value)
115+
116+
dtype = numpy.int32
112117

113118

114119
class RuntimeIndex(Scalar):
@@ -266,7 +271,7 @@ def __init__(self, a, b):
266271
@cached_property
267272
def dtype(self):
268273
a, b = self.children
269-
return a.dtype
274+
return numpy.find_common_type([], [a.dtype, b.dtype])
270275

271276

272277
class Sum(Scalar):
@@ -280,7 +285,7 @@ def __init__(self, a, b):
280285
@cached_property
281286
def dtype(self):
282287
a, b = self.children
283-
return a.dtype
288+
return numpy.find_common_type([], [a.dtype, b.dtype])
284289

285290

286291
class Product(Scalar):
@@ -294,7 +299,7 @@ def __init__(self, a, b):
294299
@cached_property
295300
def dtype(self):
296301
a, b = self.children
297-
return a.dtype
302+
return numpy.find_common_type([], [a.dtype, b.dtype])
298303

299304

300305
class Indexed(Scalar):
@@ -382,7 +387,7 @@ def __init__(self, name, shape, dtype):
382387

383388

384389
class DummyInstruction(Node):
385-
__slots__ = ("children",)
390+
__slots__ = ("children", "label")
386391
__front__ = ("label",)
387392

388393
def __init__(self, label, *children):
@@ -391,17 +396,13 @@ def __init__(self, label, *children):
391396

392397

393398
class Accumulate(Node):
394-
__slots__ = ("children",)
399+
__slots__ = ("children", "label")
395400
__front__ = ("label",)
396401

397402
def __init__(self, label, lvalue, rvalue):
398403
self.children = (lvalue, rvalue)
399404
self.label = label
400405

401-
def reconstruct(self, *args):
402-
new = type(self)(*self._cons_args(args))
403-
return new
404-
405406

406407
class FunctionCall(Node):
407408
__slots__ = ("name", "access", "free_indices", "label", "children")
@@ -417,7 +418,7 @@ def __init__(self, name, label, access, free_indices, *arguments):
417418

418419

419420
class Conditional(Scalar):
420-
__slots__ = ("children")
421+
__slots__ = ("children", )
421422

422423
def __init__(self, condition, then, else_):
423424
assert not condition.shape

pyop2/sequential.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,12 @@ def code_to_compile(self):
148148
from pyop2.codegen.builder import WrapperBuilder
149149
from pyop2.codegen.rep2loopy import generate
150150

151-
builder = WrapperBuilder(iterset=self._iterset, iteration_region=self._iteration_region, pass_layer_to_kernel=self._pass_layer_arg)
151+
builder = WrapperBuilder(kernel=self._kernel,
152+
iterset=self._iterset,
153+
iteration_region=self._iteration_region,
154+
pass_layer_to_kernel=self._pass_layer_arg)
152155
for arg in self._args:
153156
builder.add_argument(arg)
154-
builder.set_kernel(self._kernel)
155157

156158
wrapper = generate(builder)
157159
if self._iterset._extruded:

0 commit comments

Comments
 (0)