Skip to content

Commit b75090a

Browse files
committed
Merge branch 'master' into gpu
2 parents e937806 + a33fedf commit b75090a

27 files changed

+2124
-1797
lines changed

pyop2/backends/cpu.py

Lines changed: 37 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
from pyop2.types.set import Set, ExtrudedSet, Subset, MixedSet
33
from pyop2.types.dataset import DataSet, GlobalDataSet, MixedDataSet
44
from pyop2.types.map import Map, MixedMap
5-
from pyop2.parloop import AbstractParLoop, AbstractJITModule
5+
from pyop2.parloop import AbstractParLoop
6+
from pyop2.global_kernel import AbstractGlobalKernel
67
from pyop2.types.mat import Mat
78
from pyop2.glob import Global
89
from pyop2.backends import AbstractComputeBackend
10+
from pyop2.datatypes import as_ctypes, IntType
911
from petsc4py import PETSc
1012
from . import (
1113
compilation,
12-
configuration as conf,
13-
datatypes as dtypes,
1414
mpi,
1515
utils
1616
)
@@ -35,23 +35,17 @@ def _vec(self):
3535
return vec
3636

3737

38-
class JITModule(AbstractJITModule):
38+
class GlobalKernel(AbstractGlobalKernel):
39+
3940
@utils.cached_property
4041
def code_to_compile(self):
41-
from pyop2.codegen.builder import WrapperBuilder
42+
"""Return the C/C++ source code as a string."""
4243
from pyop2.codegen.rep2loopy import generate
4344

44-
builder = WrapperBuilder(kernel=self._kernel,
45-
iterset=self._iterset,
46-
iteration_region=self._iteration_region,
47-
pass_layer_to_kernel=self._pass_layer_arg)
48-
for arg in self._args:
49-
builder.add_argument(arg)
50-
51-
wrapper = generate(builder)
45+
wrapper = generate(self.builder)
5246
code = lp.generate_code_v2(wrapper)
5347

54-
if self._kernel._cpp:
48+
if self.local_kernel.cpp:
5549
from loopy.codegen.result import process_preambles
5650
preamble = "".join(process_preambles(getattr(code, "device_preambles", [])))
5751
device_code = "\n\n".join(str(dp.ast) for dp in code.device_programs)
@@ -60,53 +54,38 @@ def code_to_compile(self):
6054

6155
@PETSc.Log.EventDecorator()
6256
@mpi.collective
63-
def compile(self):
64-
# If we weren't in the cache we /must/ have arguments
65-
if not hasattr(self, '_args'):
66-
raise RuntimeError("JITModule has no args associated with it, should never happen")
67-
68-
compiler = conf.configuration["compiler"]
69-
extension = "cpp" if self._kernel._cpp else "c"
70-
cppargs = self._cppargs
71-
cppargs += ["-I%s/include" % d for d in utils.get_petsc_dir()] + \
72-
["-I%s" % d for d in self._kernel._include_dirs] + \
73-
["-I%s" % os.path.abspath(os.path.dirname(__file__))]
74-
ldargs = ["-L%s/lib" % d for d in utils.get_petsc_dir()] + \
75-
["-Wl,-rpath,%s/lib" % d for d in utils.get_petsc_dir()] + \
76-
["-lpetsc", "-lm"] + self._libraries
77-
ldargs += self._kernel._ldargs
78-
79-
self._fun = compilation.load(self,
80-
extension,
81-
self._wrapper_name,
82-
cppargs=cppargs,
83-
ldargs=ldargs,
84-
restype=ctypes.c_int,
85-
compiler=compiler,
86-
comm=self.comm)
87-
# Blow away everything we don't need any more
88-
del self._args
89-
del self._kernel
90-
del self._iterset
57+
def compile(self, comm):
    """Compile the kernel.

    The preprocessor and linker flags are assembled from the PETSc
    directories reported by ``utils.get_petsc_dir``, the include
    directories and linker arguments carried by the local kernel, and the
    directory containing this module.

    :arg comm: The communicator the compilation is collective over.
    :returns: A ctypes function pointer for the compiled function.
    """
    # The local kernel decides whether the wrapper is built as C or C++.
    if self.local_kernel.cpp:
        extension = "cpp"
    else:
        extension = "c"

    # Header search paths: PETSc first, then kernel-specific directories,
    # then the directory holding this file.
    include_flags = []
    include_flags.extend("-I%s/include" % petsc for petsc in utils.get_petsc_dir())
    include_flags.extend("-I%s" % inc for inc in self.local_kernel.include_dirs)
    include_flags.append("-I%s" % os.path.abspath(os.path.dirname(__file__)))

    # Linker flags: PETSc library path and rpath, the fixed libraries, then
    # whatever extra arguments the local kernel requests.
    link_flags = []
    link_flags.extend("-L%s/lib" % petsc for petsc in utils.get_petsc_dir())
    link_flags.extend("-Wl,-rpath,%s/lib" % petsc for petsc in utils.get_petsc_dir())
    link_flags.extend(("-lpetsc", "-lm"))
    link_flags.extend(self.local_kernel.ldargs)

    return compilation.load(
        self,
        extension,
        self.name,
        cppargs=tuple(include_flags),
        ldargs=tuple(link_flags),
        restype=ctypes.c_int,
        comm=comm,
    )
9181

9282
@utils.cached_property
9383
def argtypes(self):
94-
index_type = dtypes.as_ctypes(dtypes.IntType)
95-
argtypes = (index_type, index_type)
96-
argtypes += self._iterset._argtypes_
97-
for arg in self._args:
98-
argtypes += arg._argtypes_
99-
seen = set()
100-
for arg in self._args:
101-
maps = arg.map_tuple
102-
for map_ in maps:
103-
for k, t in zip(map_._kernel_args_, map_._argtypes_):
104-
if k in seen:
105-
continue
106-
argtypes += (t,)
107-
seen.add(k)
108-
return argtypes
109-
...
84+
# The first two arguments to the global kernel are the 'start' and 'stop'
85+
# indices. All other arguments are declared to be void pointers.
86+
dtypes = [as_ctypes(IntType)] * 2
87+
dtypes.extend([ctypes.c_voidp for _ in self.builder.wrapper_args[2:]])
88+
return tuple(dtypes)
11089

11190

11291
class ParLoop(AbstractParLoop):
@@ -128,12 +107,6 @@ def prepare_arglist(self, iterset, *args):
128107
seen.add(k)
129108
return arglist
130109

131-
@utils.cached_property
132-
def _jitmodule(self):
133-
return JITModule(self.kernel, self.iterset, *self.args,
134-
iterate=self.iteration_region,
135-
pass_layer_arg=self._pass_layer_arg)
136-
137110
@mpi.collective
138111
def _compute(self, part, fun, *arglist):
139112
with self._compute_event:

pyop2/caching.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,15 @@
3333

3434
"""Provides common base classes for cached objects."""
3535

36+
import hashlib
37+
import os
38+
from pathlib import Path
39+
import pickle
3640

41+
import cachetools
42+
43+
from pyop2.configuration import configuration
44+
from pyop2.mpi import hash_comm
3745
from pyop2.utils import cached_property
3846

3947

@@ -230,3 +238,108 @@ def _cache_key(cls, *args, **kwargs):
230238
def cache_key(self):
231239
"""Cache key."""
232240
return self._key
241+
242+
243+
cached = cachetools.cached
244+
"""Cache decorator for functions. See the cachetools documentation for more
245+
information.
246+
247+
.. note::
248+
If you intend to use this decorator to cache things that are collective
249+
across a communicator then you must include the communicator as part of
250+
the cache key. Since communicators are themselves not hashable you should
251+
use :func:`pyop2.mpi.hash_comm`.
252+
253+
You should also make sure to use unbounded caches as otherwise some ranks
254+
may evict results leading to deadlocks.
255+
"""
256+
257+
258+
def disk_cached(cache, cachedir=None, key=cachetools.keys.hashkey, collective=False):
    """Decorator for wrapping a function in a cache that stores values in memory and to disk.

    Lookups try the in-memory cache first, then the disk cache, and only
    then call the wrapped function, storing its result in both caches.

    :arg cache: The in-memory cache, usually a :class:`dict`.
    :arg cachedir: The location of the cache directory. Defaults to ``PYOP2_CACHE_DIR``.
    :arg key: Callable returning the cache key for the function inputs. If ``collective``
        is ``True`` then this function must return a 2-tuple where the first entry is the
        communicator to be collective over and the second is the key. This is required to ensure
        that deadlocks do not occur when using different subcommunicators.
    :arg collective: If ``True`` then cache lookup is done collectively over a communicator.
    :returns: A decorator that wraps a function with the two-level cache.
    """
    # Imported locally so that this function does not depend on a
    # module-level import being present.
    import functools

    if cachedir is None:
        cachedir = configuration["cache_dir"]

    def decorator(func):
        # Preserve the wrapped function's name, docstring and __wrapped__.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if collective:
                comm, disk_key = key(*args, **kwargs)
                disk_key = _as_hexdigest(disk_key)
                # The in-memory key also carries the communicator hash; the
                # on-disk key is the digest alone.
                k = hash_comm(comm), disk_key
            else:
                k = _as_hexdigest(key(*args, **kwargs))

            # first try the in-memory cache
            try:
                return cache[k]
            except KeyError:
                pass

            # then try to retrieve from disk
            if collective:
                # Only rank 0 reads the disk cache; every rank takes part in
                # the broadcast so that all ranks see the same hit or miss.
                if comm.rank == 0:
                    v = _disk_cache_get(cachedir, disk_key)
                    comm.bcast(v, root=0)
                else:
                    v = comm.bcast(None, root=0)
            else:
                v = _disk_cache_get(cachedir, k)
            # NOTE: a stored ``None`` is indistinguishable from a miss here.
            if v is not None:
                return cache.setdefault(k, v)

            # if all else fails call func and populate the caches
            v = func(*args, **kwargs)
            if collective:
                # Only rank 0 writes the disk cache entry.
                if comm.rank == 0:
                    _disk_cache_set(cachedir, disk_key, v)
            else:
                _disk_cache_set(cachedir, k, v)
            return cache.setdefault(k, v)
        return wrapper
    return decorator
309+
310+
311+
def _as_hexdigest(key):
    """Return the MD5 hex digest of the string form of ``key``.

    :arg key: Any object; its ``str()`` representation is what gets hashed.
    :returns: The digest as a string of hexadecimal characters.
    """
    encoded = str(key).encode()
    digest = hashlib.md5()
    digest.update(encoded)
    return digest.hexdigest()
313+
314+
315+
def _disk_cache_get(cachedir, key):
    """Retrieve a value from the disk cache.

    The entry is looked up at ``<cachedir>/<first two characters of key>/<rest of key>``.

    :arg cachedir: The cache directory.
    :arg key: The cache key (must be a string).
    :returns: The cached object if found, else ``None``.
    """
    prefix, remainder = key[:2], key[2:]
    entry = Path(cachedir) / prefix / remainder
    try:
        # NOTE: unpickling can execute arbitrary code, so the cache directory
        # has to be a trusted location.
        with entry.open("rb") as stream:
            cached_value = pickle.load(stream)
    except FileNotFoundError:
        # No file for this key: report a cache miss.
        cached_value = None
    return cached_value
328+
329+
330+
def _disk_cache_set(cachedir, key, value):
    """Store a new value in the disk cache.

    The value is pickled to a temporary file (named after the key and the
    current process ID) inside the destination directory and then moved
    onto the final path, so a reader never opens a partially written entry.
    If pickling or the move fails, the temporary file is removed and the
    exception propagates.

    :arg cachedir: The cache directory.
    :arg key: The cache key (must be a string).
    :arg value: The new item to store in the cache.
    """
    k1, k2 = key[:2], key[2:]
    basedir = Path(cachedir, k1)
    basedir.mkdir(parents=True, exist_ok=True)

    tmp_path = basedir.joinpath(f"{k2}_p{os.getpid()}.tmp")
    filepath = basedir.joinpath(k2)
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(value, f)
        # ``replace`` overwrites an existing entry on every platform, whereas
        # ``rename`` raises on Windows when the target already exists.
        tmp_path.replace(filepath)
    finally:
        # After a successful move the temporary name is already gone; on
        # failure this stops stale ``.tmp`` files piling up in the cache.
        try:
            tmp_path.unlink()
        except FileNotFoundError:
            pass

0 commit comments

Comments
 (0)