
Commit 1381d88

Committed on Jul 11, 2022
implements backend-specific reduction_(begin|end)
1 parent 38e2027 · commit 1381d88

File tree

4 files changed: +135, -96 lines
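
The pattern common to all three backends: reduction_begin posts one MPI allreduce per Global argument (a non-blocking Iallreduce when the MPI library implements MPI-3, a blocking Allreduce otherwise) and returns the outstanding requests, while reduction_end waits on those requests and copies the reduced buffer back into the Global's data. A minimal standalone sketch of that split-phase pattern, written directly against mpi4py and numpy; the buffer names and the SUM operation are illustrative, not PyOP2 API:

    # Minimal sketch (not PyOP2 code) of the split-phase reduction used in this commit.
    import numpy as np
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    local = np.array([float(comm.rank)])   # per-rank partial result (plays the role of glob._data)
    reduced = np.empty_like(local)         # receive buffer (plays the role of glob._buf)

    if MPI.VERSION >= 3:
        req = comm.Iallreduce(local, reduced, op=MPI.SUM)  # "reduction_begin": post and return
        # ... other work can overlap with the collective here ...
        req.Wait()                                         # "reduction_end": wait for completion
    else:
        comm.Allreduce(local, reduced, op=MPI.SUM)         # blocking fallback for MPI < 3

    local[:] = reduced  # copy the reduced value back, as the backends do into glob._data

Splitting the reduction into a begin/end pair is what lets a caller overlap the collective with other work between the two calls.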
 

‎pyop2/backends/cpu.py

Lines changed: 41 additions & 29 deletions

@@ -4,6 +4,7 @@
 from pyop2.types.map import Map, MixedMap
 from pyop2.parloop import AbstractParloop
 from pyop2.global_kernel import AbstractGlobalKernel
+from pyop2.types.access import INC, MIN, MAX
 from pyop2.types.mat import Mat
 from pyop2.types.glob import Global
 from pyop2.backends import AbstractComputeBackend

@@ -23,14 +24,16 @@ class Dat(BaseDat):
     @utils.cached_property
     def _vec(self):
         assert self.dtype == PETSc.ScalarType, \
-            "Can't create Vec with type %s, must be %s" % (self.dtype, PETSc.ScalarType)
+            "Can't create Vec with type %s, must be %s" % (self.dtype,
+                                                           PETSc.ScalarType)
         # Can't duplicate layout_vec of dataset, because we then
         # carry around extra unnecessary data.
         # But use getSizes to save an Allreduce in computing the
         # global size.
         size = self.dataset.layout_vec.getSizes()
         data = self._data[:size[0]]
-        vec = PETSc.Vec().createWithArray(data, size=size, bsize=self.cdim, comm=self.comm)
+        vec = PETSc.Vec().createWithArray(data, size=size,
+                                          bsize=self.cdim, comm=self.comm)
         return vec


@@ -46,9 +49,10 @@ def code_to_compile(self):

         if self.local_kernel.cpp:
             from loopy.codegen.result import process_preambles
-            preamble = "".join(process_preambles(getattr(code, "device_preambles", [])))
+            preamble = "".join(
+                process_preambles(getattr(code, "device_preambles", [])))
             device_code = "\n\n".join(str(dp.ast) for dp in code.device_programs)
-            return preamble + "\nextern \"C\" {\n" + device_code + "\n}\n"
+            return preamble + '\nextern "C" {\n' + device_code + "\n}\n"
         return code.device_code()

     @PETSc.Log.EventDecorator()

@@ -81,32 +85,40 @@ def compile(self, comm):

 class Parloop(AbstractParloop):

-    def prepare_arglist(self, iterset, *args):
-        arglist = iterset._kernel_args_
-        for arg in args:
-            arglist += arg._kernel_args_
-        seen = set()
-        for arg in args:
-            maps = arg.map_tuple
-            for map_ in maps:
-                if map_ is None:
-                    continue
-                for k in map_._kernel_args_:
-                    if k in seen:
-                        continue
-                    arglist += (k,)
-                    seen.add(k)
-        return arglist
-
+    @PETSc.Log.EventDecorator("ParLoopRednBegin")
+    @mpi.collective
+    def reduction_begin(self):
+        """Begin reductions."""
+        requests = []
+        for idx in self._reduction_idxs:
+            glob = self.arguments[idx].data
+            mpi_op = {INC: mpi.MPI.SUM,
+                      MIN: mpi.MPI.MIN,
+                      MAX: mpi.MPI.MAX}.get(self.accesses[idx])
+
+            if mpi.MPI.VERSION >= 3:
+                requests.append(self.comm.Iallreduce(glob._data,
+                                                     glob._buf,
+                                                     op=mpi_op))
+            else:
+                self.comm.Allreduce(glob._data, glob._buf, op=mpi_op)
+        return tuple(requests)
+
+    @PETSc.Log.EventDecorator("ParLoopRednEnd")
     @mpi.collective
-    def _compute(self, part):
-        """Execute the kernel over all members of a MPI-part of the iteration space.
+    def reduction_end(self, requests):
+        """Finish reductions."""
+        if mpi.MPI.VERSION >= 3:
+            for idx, req in zip(self._reduction_idxs, requests):
+                req.Wait()
+                glob = self.arguments[idx].data
+                glob._data[:] = glob._buf
+        else:
+            assert len(requests) == 0

-        :arg part: The :class:`SetPartition` to compute over.
-        """
-        with self._compute_event():
-            PETSc.Log.logFlops(part.size*self.num_flops)
-            self.global_kernel(self.comm, part.offset, part.offset+part.size, *self.arglist)
+            for idx in self._reduction_idxs:
+                glob = self.arguments[idx].data
+                glob._data[:] = glob._buf


 class CPUBackend(AbstractComputeBackend):

@@ -126,7 +138,7 @@ class CPUBackend(AbstractComputeBackend):
     Mat = Mat
     Global = Global
     GlobalDataSet = GlobalDataSet
-    PETScVecType = 'standard'
+    PETScVecType = "standard"

     def turn_on_offloading(self):
         pass
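
The call site for these two hooks is not part of this commit; a hypothetical driver, assuming only the interface above (a Parloop exposing _compute, reduction_begin and reduction_end) plus assumed names for the iteration-set partitions, could interleave them like this:

    # Hypothetical driver (not from this commit): one way a caller could use the hooks above.
    # `core_part` and `owned_part` are assumed names for partitions of the iteration set.
    def execute(parloop, core_part, owned_part):
        parloop._compute(core_part)            # run the local kernel over core entities
        parloop._compute(owned_part)           # ... and over owned entities
        requests = parloop.reduction_begin()   # post one Iallreduce per Global argument
        # other communication or bookkeeping could overlap here
        parloop.reduction_end(requests)        # wait, then copy glob._buf into glob._data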

‎pyop2/backends/cuda.py

Lines changed: 42 additions & 34 deletions

@@ -10,7 +10,6 @@
                              AVAILABLE_ON_DEVICE_ONLY,
                              AVAILABLE_ON_BOTH,
                              DataAvailability)
-from pyop2.profiling import timed_region
 from pyop2.configuration import configuration
 from pyop2.types.set import (MixedSet, Subset as BaseSubset,
                              ExtrudedSet as BaseExtrudedSet,

@@ -20,7 +19,7 @@
 from pyop2.types.dataset import DataSet, GlobalDataSet, MixedDataSet
 from pyop2.types.mat import Mat
 from pyop2.types.glob import Global as BaseGlobal
-from pyop2.types.access import RW, READ, INC
+from pyop2.types.access import RW, READ, INC, MIN, MAX
 from pyop2.parloop import AbstractParloop
 from pyop2.global_kernel import AbstractGlobalKernel
 from pyop2.backends import AbstractComputeBackend, cpu as cpu_backend

@@ -253,6 +252,9 @@ def _kernel_args_(self):
                 return (self._cuda_data.gpudata,)
             else:
                 self.ensure_availability_on_host()
+                # tell petsc that we have updated the data on the host
+                with self.vec as v:
+                    v.array_w
                 return (self._data.ctypes.data, )
         else:
             if cuda_backend.offloading:

@@ -365,7 +367,7 @@ def ensure_availability_on_device(self):

     def ensure_availability_on_host(self):
         if not self.is_available_on_host():
-            self._cuda.get(ary=self._data)
+            self._cuda_data.get(ary=self._data)
             self._availability_flag = AVAILABLE_ON_BOTH

     @property

@@ -538,39 +540,45 @@ def compile(self, comm):

 class Parloop(AbstractParloop):

-    def prepare_arglist(self, iterset, *args):
-        nbytes = 0
-
-        arglist = iterset._kernel_args_
-        for arg in args:
-            arglist += arg._kernel_args_
-            if arg.access is INC:
-                nbytes += arg.data.nbytes * 2
+    @PETSc.Log.EventDecorator("ParLoopRednBegin")
+    @mpi.collective
+    def reduction_begin(self):
+        """Begin reductions."""
+        requests = []
+        for idx in self._reduction_idxs:
+            glob = self.arguments[idx].data
+            mpi_op = {INC: mpi.MPI.SUM,
+                      MIN: mpi.MPI.MIN,
+                      MAX: mpi.MPI.MAX}.get(self.accesses[idx])
+
+            if mpi.MPI.VERSION >= 3:
+                glob.ensure_availability_on_host()
+                requests.append(self.comm.Iallreduce(glob._data,
+                                                     glob._buf,
+                                                     op=mpi_op))
             else:
-                nbytes += arg.data.nbytes
-        seen = set()
-        for arg in args:
-            maps = arg.map_tuple
-            for map_ in maps:
-                for k in map_._kernel_args_:
-                    if k in seen:
-                        continue
-                    arglist += map_._kernel_args_
-                    seen.add(k)
-                    nbytes += map_.values.nbytes
-
-        self.nbytes = nbytes
-
-        return arglist
+                self.comm.Allreduce(glob._data, glob._buf, op=mpi_op)
+        return tuple(requests)

+    @PETSc.Log.EventDecorator("ParLoopRednEnd")
     @mpi.collective
-    def _compute(self, part, fun, *arglist):
-        if part.size == 0:
-            return
+    def reduction_end(self, requests):
+        """Finish reductions."""
+        if mpi.MPI.VERSION >= 3:
+            for idx, req in zip(self._reduction_idxs, requests):
+                req.Wait()
+                glob = self.arguments[idx].data
+                glob._data[:] = glob._buf
+                glob._availability_flag = AVAILABLE_ON_HOST_ONLY
+                glob.ensure_availability_on_device()
+        else:
+            assert len(requests) == 0

-        with timed_region("Parloop_{0}_{1}".format(self.iterset.name,
-                                                   self._jitmodule._wrapper_name)):
-            fun(part.offset, part.offset + part.size, *arglist)
+            for idx in self._reduction_idxs:
+                glob = self.arguments[idx].data
+                glob._data[:] = glob._buf
+                glob._availability_flag = AVAILABLE_ON_HOST_ONLY
+                glob.ensure_availability_on_device()


 class CUDABackend(AbstractComputeBackend):

@@ -602,12 +610,12 @@ def __init__(self):

     def turn_on_offloading(self):
         self.offloading = True
-        self.ParLoop = self.Parloop_offloading
+        self.Parloop = self.Parloop_offloading
         self.GlobalKernel = self.GlobalKernel_offloading

     def turn_off_offloading(self):
         self.offloading = False
-        self.ParLoop = self.Parloop_no_offloading
+        self.Parloop = self.Parloop_no_offloading
         self.GlobalKernel = self.GlobalKernel_no_offloading

     @property
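
Compared with the CPU backend, the CUDA reduction_begin first calls ensure_availability_on_host so that MPI reduces host memory, and reduction_end marks each Global as AVAILABLE_ON_HOST_ONLY before ensure_availability_on_device pushes the reduced value back; the _kernel_args_ hunk additionally touches vec.array_w so that, per its comment, PETSc is told the host data changed. A rough sketch of that availability bookkeeping, using a hypothetical stand-in for the real Global (which wraps a pycuda array):

    # Illustrative only: a hypothetical mirrored array mimicking the availability
    # flags used by the CUDA backend's Global above.
    AVAILABLE_ON_HOST_ONLY, AVAILABLE_ON_DEVICE_ONLY, AVAILABLE_ON_BOTH = range(3)

    class MirroredArray:
        def __init__(self, host_ary, device_ary):
            self._data = host_ary                  # numpy array on the host
            self._device = device_ary              # device handle, e.g. a pycuda GPUArray
            self._availability_flag = AVAILABLE_ON_BOTH

        def ensure_availability_on_host(self):
            if self._availability_flag == AVAILABLE_ON_DEVICE_ONLY:
                self._device.get(ary=self._data)   # device -> host copy
                self._availability_flag = AVAILABLE_ON_BOTH

        def ensure_availability_on_device(self):
            if self._availability_flag == AVAILABLE_ON_HOST_ONLY:
                self._device.set(self._data)       # host -> device copy
                self._availability_flag = AVAILABLE_ON_BOTH

reduction_end writes the reduced values into the host array, so flagging the Global host-only first is what forces ensure_availability_on_device to actually perform the copy.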

‎pyop2/backends/opencl.py

Lines changed: 40 additions & 10 deletions

@@ -9,7 +9,6 @@
                              AVAILABLE_ON_DEVICE_ONLY,
                              AVAILABLE_ON_BOTH,
                              DataAvailability)
-from pyop2.profiling import timed_region
 from pyop2.configuration import configuration
 from pyop2.types.set import (MixedSet, Subset as BaseSubset,
                              ExtrudedSet as BaseExtrudedSet,

@@ -19,7 +18,7 @@
 from pyop2.types.dataset import DataSet, GlobalDataSet, MixedDataSet
 from pyop2.types.mat import Mat
 from pyop2.types.glob import Global as BaseGlobal
-from pyop2.types.access import RW, READ, INC
+from pyop2.types.access import RW, READ, INC, MIN, MAX
 from pyop2.parloop import AbstractParloop
 from pyop2.global_kernel import AbstractGlobalKernel
 from pyop2.backends import AbstractComputeBackend, cpu as cpu_backend

@@ -582,14 +581,45 @@ def prepare_arglist(self, iterset, *args):

         return arglist

+    @PETSc.Log.EventDecorator("ParLoopRednBegin")
     @mpi.collective
-    def _compute(self, part, fun, *arglist):
-        if part.size == 0:
-            return
+    def reduction_begin(self):
+        """Begin reductions."""
+        requests = []
+        for idx in self._reduction_idxs:
+            glob = self.arguments[idx].data
+            mpi_op = {INC: mpi.MPI.SUM,
+                      MIN: mpi.MPI.MIN,
+                      MAX: mpi.MPI.MAX}.get(self.accesses[idx])
+
+            if mpi.MPI.VERSION >= 3:
+                glob.ensure_availability_on_host()
+                requests.append(self.comm.Iallreduce(glob._data,
+                                                     glob._buf,
+                                                     op=mpi_op))
+            else:
+                self.comm.Allreduce(glob._data, glob._buf, op=mpi_op)
+        return tuple(requests)
+
+    @PETSc.Log.EventDecorator("ParLoopRednEnd")
+    @mpi.collective
+    def reduction_end(self, requests):
+        """Finish reductions."""
+        if mpi.MPI.VERSION >= 3:
+            for idx, req in zip(self._reduction_idxs, requests):
+                req.Wait()
+                glob = self.arguments[idx].data
+                glob._data[:] = glob._buf
+                glob._availability_flag = AVAILABLE_ON_HOST_ONLY
+                glob.ensure_availability_on_device()
+        else:
+            assert len(requests) == 0

-        with timed_region("Parloop_{0}_{1}".format(self.iterset.name,
-                                                   self._jitmodule._wrapper_name)):
-            fun(part.offset, part.offset + part.size, *arglist)
+            for idx in self._reduction_idxs:
+                glob = self.arguments[idx].data
+                glob._data[:] = glob._buf
+                glob._availability_flag = AVAILABLE_ON_HOST_ONLY
+                glob.ensure_availability_on_device()


 class OpenCLBackend(AbstractComputeBackend):

@@ -642,12 +672,12 @@ def queue(self):

     def turn_on_offloading(self):
         self.offloading = True
-        self.ParLoop = self.Parloop_offloading
+        self.Parloop = self.Parloop_offloading
         self.GlobalKernel = self.GlobalKernel_offloading

     def turn_off_offloading(self):
         self.offloading = False
-        self.ParLoop = self.Parloop_no_offloading
+        self.Parloop = self.Parloop_no_offloading
         self.GlobalKernel = self.GlobalKernel_no_offloading

     @property
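
Both the CUDA and OpenCL hunks also change turn_on_offloading/turn_off_offloading to assign self.Parloop instead of self.ParLoop: with the old spelling the assignment created a separately named attribute, so whatever backend.Parloop previously resolved to was left untouched. A minimal sketch of that pitfall, using hypothetical stand-in classes (the sketch assumes a class-level default, which may differ from how the real backends initialise the attribute):

    # Minimal sketch of the attribute-name bug fixed above (hypothetical classes).
    class Parloop_no_offloading: ...
    class Parloop_offloading: ...

    class Backend:
        Parloop = Parloop_no_offloading            # assumed default

        def turn_on_offloading_buggy(self):
            self.ParLoop = Parloop_offloading      # wrong case: sets a different attribute

        def turn_on_offloading_fixed(self):
            self.Parloop = Parloop_offloading      # shadows the default as intended

    b = Backend()
    b.turn_on_offloading_buggy()
    assert b.Parloop is Parloop_no_offloading      # callers still get the non-offloading class
    b.turn_on_offloading_fixed()
    assert b.Parloop is Parloop_offloading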

‎pyop2/parloop.py

Lines changed: 12 additions & 23 deletions

@@ -192,6 +192,16 @@ def compute(self):
         # Parloop.compute is an alias for Parloop.__call__
         self()

+    @mpi.collective
+    def _compute(self, part):
+        """Execute the kernel over all members of a MPI-part of the iteration space.
+
+        :arg part: The :class:`SetPartition` to compute over.
+        """
+        with self._compute_event():
+            PETSc.Log.logFlops(part.size*self.num_flops)
+            self.global_kernel(self.comm, part.offset, part.offset+part.size, *self.arglist)
+
     @PETSc.Log.EventDecorator("ParLoopExecute")
     @mpi.collective
     def __call__(self):

@@ -341,34 +351,13 @@ def _l2g_idxs(self):
     @mpi.collective
     def reduction_begin(self):
         """Begin reductions."""
-        requests = []
-        for idx in self._reduction_idxs:
-            glob = self.arguments[idx].data
-            mpi_op = {Access.INC: mpi.MPI.SUM,
-                      Access.MIN: mpi.MPI.MIN,
-                      Access.MAX: mpi.MPI.MAX}.get(self.accesses[idx])
-
-            if mpi.MPI.VERSION >= 3:
-                requests.append(self.comm.Iallreduce(glob._data, glob._buf, op=mpi_op))
-            else:
-                self.comm.Allreduce(glob._data, glob._buf, op=mpi_op)
-        return tuple(requests)
+        raise NotImplementedError("Backend-specific logic not implemented")

     @PETSc.Log.EventDecorator("ParLoopRednEnd")
     @mpi.collective
     def reduction_end(self, requests):
         """Finish reductions."""
-        if mpi.MPI.VERSION >= 3:
-            for idx, req in zip(self._reduction_idxs, requests):
-                req.Wait()
-                glob = self.arguments[idx].data
-                glob.data[:] = glob._buf
-        else:
-            assert len(requests) == 0
-
-            for idx in self._reduction_idxs:
-                glob = self.arguments[idx].data
-                glob.data[:] = glob._buf
+        raise NotImplementedError("Backend-specific logic not implemented")

     @cached_property
     def _reduction_idxs(self):
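
With the base-class bodies replaced by NotImplementedError, AbstractParloop now treats the two reduction phases purely as backend hooks: reduction_begin must return the outstanding requests and reduction_end must consume exactly that tuple. A sketch of the contract a further backend would have to satisfy, assuming only the interface visible in this commit:

    # Sketch of the hook contract (assumes only the AbstractParloop interface shown above).
    from pyop2.parloop import AbstractParloop

    class MyBackendParloop(AbstractParloop):
        def reduction_begin(self):
            """Begin reductions: start one reduction per Global argument and
            return whatever handles reduction_end needs to finish them."""
            return ()                      # nothing outstanding in this trivial example

        def reduction_end(self, requests):
            """Finish reductions: wait on the returned handles and move the
            reduced values into each Global's data."""
            assert len(requests) == 0      # matches what reduction_begin returned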
