Skip to content

Commit e0a4d3a

Browse files
Passthrough params (#708)
Pass objects to local kernels without packing and unpacking. --------- Co-authored-by: Connor Ward <[email protected]>
1 parent ad0c430 commit e0a4d3a

File tree

7 files changed

+170
-17
lines changed

7 files changed

+170
-17
lines changed

pyop2/codegen/builder.py

+29-11
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
from functools import reduce
55

66
import numpy
7-
from loopy.types import OpaqueType
87
from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg,
9-
MatKernelArg, MixedMatKernelArg, PermutedMapKernelArg, ComposedMapKernelArg)
8+
MatKernelArg, MixedMatKernelArg, PermutedMapKernelArg, ComposedMapKernelArg, PassthroughKernelArg)
109
from pyop2.codegen.representation import (Accumulate, Argument, Comparison, Conditional,
1110
DummyInstruction, Extent, FixedIndex,
1211
FunctionCall, Index, Indexed,
@@ -16,16 +15,13 @@
1615
PreUnpackInst, Product, RuntimeIndex,
1716
Sum, Symbol, UnpackInst, Variable,
1817
When, Zero)
19-
from pyop2.datatypes import IntType
18+
from pyop2.datatypes import IntType, OpaqueType
2019
from pyop2.op2 import (ALL, INC, MAX, MIN, ON_BOTTOM, ON_INTERIOR_FACETS,
2120
ON_TOP, READ, RW, WRITE)
2221
from pyop2.utils import cached_property
2322

2423

25-
class PetscMat(OpaqueType):
26-
27-
def __init__(self):
28-
super().__init__(name="Mat")
24+
MatType = OpaqueType("Mat")
2925

3026

3127
def _Remainder(a, b):
@@ -226,6 +222,23 @@ def emit_unpack_instruction(self, *, loop_indices=None):
226222
"""Either yield an instruction, or else return an empty tuple (to indicate no instruction)"""
227223

228224

225+
class PassthroughPack(Pack):
226+
def __init__(self, outer):
227+
self.outer = outer
228+
229+
def kernel_arg(self, loop_indices=None):
230+
return self.outer
231+
232+
def pack(self, loop_indices=None):
233+
pass
234+
235+
def emit_pack_instruction(self, **kwargs):
236+
return ()
237+
238+
def emit_unpack_instruction(self, **kwargs):
239+
return ()
240+
241+
229242
class GlobalPack(Pack):
230243

231244
def __init__(self, outer, access, init_with_zero=False):
@@ -813,7 +826,12 @@ def add_argument(self, arg):
813826
dtype = local_arg.dtype
814827
interior_horizontal = self.iteration_region == ON_INTERIOR_FACETS
815828

816-
if isinstance(arg, GlobalKernelArg):
829+
if isinstance(arg, PassthroughKernelArg):
830+
argument = Argument((), dtype, pfx="arg")
831+
pack = PassthroughPack(argument)
832+
self.arguments.append(argument)
833+
834+
elif isinstance(arg, GlobalKernelArg):
817835
argument = Argument(arg.dim, dtype, pfx="glob")
818836

819837
pack = GlobalPack(argument, access,
@@ -856,7 +874,7 @@ def add_argument(self, arg):
856874
pack = MixedDatPack(packs, access, dtype,
857875
interior_horizontal=interior_horizontal)
858876
elif isinstance(arg, MatKernelArg):
859-
argument = Argument((), PetscMat(), pfx="mat")
877+
argument = Argument((), MatType, pfx="mat")
860878
maps = tuple(self._add_map(m, arg.unroll)
861879
for m in arg.maps)
862880
pack = arg.pack(argument, access, maps,
@@ -866,7 +884,7 @@ def add_argument(self, arg):
866884
elif isinstance(arg, MixedMatKernelArg):
867885
packs = []
868886
for a in arg:
869-
argument = Argument((), PetscMat(), pfx="mat")
887+
argument = Argument((), MatType, pfx="mat")
870888
maps = tuple(self._add_map(m, a.unroll)
871889
for m in a.maps)
872890

@@ -949,7 +967,7 @@ def kernel_call(self):
949967
args = self.kernel_args
950968
access = tuple(self.loopy_argument_accesses)
951969
# assuming every index is free index
952-
free_indices = set(itertools.chain.from_iterable(arg.multiindex for arg in args))
970+
free_indices = set(itertools.chain.from_iterable(arg.multiindex for arg in args if isinstance(arg, Indexed)))
953971
# remove runtime index
954972
free_indices = tuple(i for i in free_indices if isinstance(i, Index))
955973
if self.pass_layer_to_kernel:

pyop2/codegen/representation.py

-2
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,6 @@ def __new__(cls, aggregate, multiindex):
352352
for index, extent in zip(multiindex, aggregate.shape):
353353
if isinstance(index, Index):
354354
index.set_extent(extent)
355-
if not multiindex:
356-
return aggregate
357355

358356
self = super().__new__(cls)
359357
self.children = (aggregate, multiindex)

pyop2/datatypes.py

+8
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,11 @@ def dtype_limits(dtype):
6969
except ValueError as e:
7070
raise ValueError("Unable to determine numeric limits from %s" % dtype) from e
7171
return info.min, info.max
72+
73+
74+
class OpaqueType(lp.types.OpaqueType):
75+
def __init__(self, name):
76+
super().__init__(name=name)
77+
78+
def __repr__(self):
79+
return self.name

pyop2/global_kernel.py

+10
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,16 @@ def pack(self):
206206
return DatPack
207207

208208

209+
class PassthroughKernelArg:
210+
@property
211+
def cache_key(self):
212+
return type(self)
213+
214+
@property
215+
def maps(self):
216+
return ()
217+
218+
209219
@dataclass(frozen=True)
210220
class MixedMatKernelArg:
211221
"""Class representing a :class:`pyop2.types.MixedDat` being passed to the kernel.

pyop2/op2.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import atexit
3737

3838
from pyop2.configuration import configuration
39+
from pyop2.datatypes import OpaqueType # noqa: F401
3940
from pyop2.logger import debug, info, warning, error, critical, set_log_level
4041
from pyop2.mpi import MPI, COMM_WORLD, collective
4142

@@ -52,7 +53,7 @@
5253
from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg, # noqa: F401
5354
MatKernelArg, MixedMatKernelArg, MapKernelArg, GlobalKernel)
5455
from pyop2.parloop import (GlobalParloopArg, DatParloopArg, MixedDatParloopArg, # noqa: F401
55-
MatParloopArg, MixedMatParloopArg, Parloop, parloop, par_loop)
56+
MatParloopArg, MixedMatParloopArg, PassthroughArg, Parloop, parloop, par_loop)
5657
from pyop2.parloop import (GlobalLegacyArg, DatLegacyArg, MixedDatLegacyArg, # noqa: F401
5758
MatLegacyArg, MixedMatLegacyArg, LegacyParloop, ParLoop)
5859

pyop2/parloop.py

+85-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from pyop2.datatypes import as_numpy_dtype
1414
from pyop2.exceptions import KernelTypeError, MapValueError, SetTypeError
1515
from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg,
16-
MatKernelArg, MixedMatKernelArg, GlobalKernel)
16+
MatKernelArg, MixedMatKernelArg, PassthroughKernelArg, GlobalKernel)
1717
from pyop2.local_kernel import LocalKernel, CStringLocalKernel, LoopyLocalKernel
1818
from pyop2.types import (Access, Global, AbstractDat, Dat, DatView, MixedDat, Mat, Set,
1919
MixedSet, ExtrudedSet, Subset, Map, ComposedMap, MixedMap)
@@ -39,6 +39,10 @@ class GlobalParloopArg(ParloopArg):
3939

4040
data: Global
4141

42+
@property
43+
def _kernel_args_(self):
44+
return self.data._kernel_args_
45+
4246
@property
4347
def map_kernel_args(self):
4448
return ()
@@ -59,6 +63,10 @@ def __post_init__(self):
5963
if self.map_ is not None:
6064
self.check_map(self.map_)
6165

66+
@property
67+
def _kernel_args_(self):
68+
return self.data._kernel_args_
69+
6270
@property
6371
def map_kernel_args(self):
6472
return self.map_._kernel_args_ if self.map_ else ()
@@ -81,6 +89,10 @@ class MixedDatParloopArg(ParloopArg):
8189
def __post_init__(self):
8290
self.check_map(self.map_)
8391

92+
@property
93+
def _kernel_args_(self):
94+
return self.data._kernel_args_
95+
8496
@property
8597
def map_kernel_args(self):
8698
return self.map_._kernel_args_ if self.map_ else ()
@@ -102,6 +114,10 @@ def __post_init__(self):
102114
for m in self.maps:
103115
self.check_map(m)
104116

117+
@property
118+
def _kernel_args_(self):
119+
return self.data._kernel_args_
120+
105121
@property
106122
def map_kernel_args(self):
107123
rmap, cmap = self.maps
@@ -120,12 +136,34 @@ def __post_init__(self):
120136
for m in self.maps:
121137
self.check_map(m)
122138

139+
@property
140+
def _kernel_args_(self):
141+
return self.data._kernel_args_
142+
123143
@property
124144
def map_kernel_args(self):
125145
rmap, cmap = self.maps
126146
return tuple(itertools.chain(*itertools.product(rmap._kernel_args_, cmap._kernel_args_)))
127147

128148

149+
@dataclass
150+
class PassthroughParloopArg(ParloopArg):
151+
# a pointer
152+
data: int
153+
154+
@property
155+
def _kernel_args_(self):
156+
return (self.data,)
157+
158+
@property
159+
def map_kernel_args(self):
160+
return ()
161+
162+
@property
163+
def maps(self):
164+
return ()
165+
166+
129167
class Parloop:
130168
"""A parallel loop invocation.
131169
@@ -167,7 +205,7 @@ def arglist(self):
167205
"""Prepare the argument list for calling generated code."""
168206
arglist = self.iterset._kernel_args_
169207
for d in self.arguments:
170-
arglist += d.data._kernel_args_
208+
arglist += d._kernel_args_
171209

172210
# Collect an ordered set of maps (ignore duplicates)
173211
maps = {m: None for d in self.arguments for m in d.map_kernel_args}
@@ -224,6 +262,8 @@ def __call__(self):
224262
def increment_dat_version(self):
225263
"""Increment dat versions of :class:`DataCarrier`s in the arguments."""
226264
for lk_arg, gk_arg, pl_arg in self.zipped_arguments:
265+
if isinstance(pl_arg, PassthroughParloopArg):
266+
continue
227267
assert isinstance(pl_arg.data, DataCarrier)
228268
if lk_arg.access is not Access.READ:
229269
if pl_arg.data in self.reduced_globals:
@@ -520,6 +560,10 @@ class GlobalLegacyArg(LegacyArg):
520560
data: Global
521561
access: Access
522562

563+
@property
564+
def dtype(self):
565+
return self.data.dtype
566+
523567
@property
524568
def global_kernel_arg(self):
525569
return GlobalKernelArg(self.data.dim)
@@ -537,6 +581,10 @@ class DatLegacyArg(LegacyArg):
537581
map_: Optional[Map]
538582
access: Access
539583

584+
@property
585+
def dtype(self):
586+
return self.data.dtype
587+
540588
@property
541589
def global_kernel_arg(self):
542590
map_arg = self.map_._global_kernel_arg if self.map_ is not None else None
@@ -556,6 +604,10 @@ class MixedDatLegacyArg(LegacyArg):
556604
map_: MixedMap
557605
access: Access
558606

607+
@property
608+
def dtype(self):
609+
return self.data.dtype
610+
559611
@property
560612
def global_kernel_arg(self):
561613
args = []
@@ -579,6 +631,10 @@ class MatLegacyArg(LegacyArg):
579631
lgmaps: Optional[Tuple[Any, Any]] = None
580632
needs_unrolling: Optional[bool] = False
581633

634+
@property
635+
def dtype(self):
636+
return self.data.dtype
637+
582638
@property
583639
def global_kernel_arg(self):
584640
map_args = [m._global_kernel_arg for m in self.maps]
@@ -599,6 +655,10 @@ class MixedMatLegacyArg(LegacyArg):
599655
lgmaps: Tuple[Any] = None
600656
needs_unrolling: Optional[bool] = False
601657

658+
@property
659+
def dtype(self):
660+
return self.data.dtype
661+
602662
@property
603663
def global_kernel_arg(self):
604664
nrows, ncols = self.data.sparsity.shape
@@ -618,6 +678,28 @@ def parloop_arg(self):
618678
return MixedMatParloopArg(self.data, tuple(self.maps), self.lgmaps)
619679

620680

681+
@dataclass
682+
class PassthroughArg(LegacyArg):
683+
"""Argument that is simply passed to the local kernel without packing.
684+
685+
:param dtype: The datatype of the argument. This is needed for code generation.
686+
:param data: A pointer to the data.
687+
"""
688+
# We don't know what the local kernel is doing with this argument
689+
access = Access.RW
690+
691+
dtype: Any
692+
data: int
693+
694+
@property
695+
def global_kernel_arg(self):
696+
return PassthroughKernelArg()
697+
698+
@property
699+
def parloop_arg(self):
700+
return PassthroughParloopArg(self.data)
701+
702+
621703
def ParLoop(*args, **kwargs):
622704
return LegacyParloop(*args, **kwargs)
623705

@@ -641,7 +723,7 @@ def LegacyParloop(local_knl, iterset, *args, **kwargs):
641723
# finish building the local kernel
642724
local_knl.accesses = tuple(a.access for a in args)
643725
if isinstance(local_knl, CStringLocalKernel):
644-
local_knl.dtypes = tuple(a.data.dtype for a in args)
726+
local_knl.dtypes = tuple(a.dtype for a in args)
645727

646728
global_knl_args = tuple(a.global_kernel_arg for a in args)
647729
extruded = iterset._extruded

test/unit/test_direct_loop.py

+36
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
import pytest
3636
import numpy as np
37+
from petsc4py import PETSc
3738

3839
from pyop2 import op2
3940
from pyop2.exceptions import MapValueError
@@ -249,6 +250,41 @@ def test_kernel_cplusplus(self, delems):
249250

250251
assert (y.data == 10.5).all()
251252

253+
def test_passthrough_mat(self):
254+
niters = 10
255+
iterset = op2.Set(niters)
256+
257+
c_kernel = """
258+
static void mat_inc(Mat mat) {
259+
PetscScalar values[] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
260+
PetscInt idxs[] = {0, 2, 4};
261+
MatSetValues(mat, 3, idxs, 3, idxs, values, ADD_VALUES);
262+
}
263+
"""
264+
kernel = op2.Kernel(c_kernel, "mat_inc")
265+
266+
# create a tiny 5x5 sparse matrix
267+
petsc_mat = PETSc.Mat().create()
268+
petsc_mat.setSizes(5)
269+
petsc_mat.setUp()
270+
petsc_mat.setValues([0, 2, 4], [0, 2, 4], np.zeros((3, 3), dtype=PETSc.ScalarType))
271+
petsc_mat.assemble()
272+
273+
arg = op2.PassthroughArg(op2.OpaqueType("Mat"), petsc_mat.handle)
274+
op2.par_loop(kernel, iterset, arg)
275+
petsc_mat.assemble()
276+
277+
assert np.allclose(
278+
petsc_mat.getValues(range(5), range(5)),
279+
[
280+
[10, 0, 20, 0, 30],
281+
[0]*5,
282+
[40, 0, 50, 0, 60],
283+
[0]*5,
284+
[70, 0, 80, 0, 90],
285+
]
286+
)
287+
252288

253289
if __name__ == '__main__':
254290
import os

0 commit comments

Comments
 (0)