22from pyop2 .types .set import Set , ExtrudedSet , Subset , MixedSet
33from pyop2 .types .dataset import DataSet , GlobalDataSet , MixedDataSet
44from pyop2 .types .map import Map , MixedMap
5- from pyop2 .parloop import AbstractParLoop , AbstractJITModule
5+ from pyop2 .parloop import AbstractParLoop
6+ from pyop2 .global_kernel import AbstractGlobalKernel
67from pyop2 .types .mat import Mat
78from pyop2 .glob import Global
89from pyop2 .backends import AbstractComputeBackend
10+ from pyop2 .datatypes import as_ctypes , IntType
911from petsc4py import PETSc
1012from . import (
1113 compilation ,
12- configuration as conf ,
13- datatypes as dtypes ,
1414 mpi ,
1515 utils
1616)
@@ -35,23 +35,17 @@ def _vec(self):
3535 return vec
3636
3737
38- class JITModule (AbstractJITModule ):
38+ class GlobalKernel (AbstractGlobalKernel ):
39+
3940 @utils .cached_property
4041 def code_to_compile (self ):
41- from pyop2 . codegen . builder import WrapperBuilder
42+ """Return the C/C++ source code as a string."""
4243 from pyop2 .codegen .rep2loopy import generate
4344
44- builder = WrapperBuilder (kernel = self ._kernel ,
45- iterset = self ._iterset ,
46- iteration_region = self ._iteration_region ,
47- pass_layer_to_kernel = self ._pass_layer_arg )
48- for arg in self ._args :
49- builder .add_argument (arg )
50-
51- wrapper = generate (builder )
45+ wrapper = generate (self .builder )
5246 code = lp .generate_code_v2 (wrapper )
5347
54- if self ._kernel . _cpp :
48+ if self .local_kernel . cpp :
5549 from loopy .codegen .result import process_preambles
5650 preamble = "" .join (process_preambles (getattr (code , "device_preambles" , [])))
5751 device_code = "\n \n " .join (str (dp .ast ) for dp in code .device_programs )
@@ -60,53 +54,38 @@ def code_to_compile(self):
6054
@PETSc.Log.EventDecorator()
@mpi.collective
def compile(self, comm):
    """Build the wrapper code and load the resulting function.

    :arg comm: The communicator the compilation is collective over.
    :returns: A ctypes function pointer for the compiled function.
    """
    # The source file suffix follows the local kernel's C++ flag.
    if self.local_kernel.cpp:
        extension = "cpp"
    else:
        extension = "c"

    # Preprocessor flags: PETSc headers, the local kernel's own include
    # directories, and the directory containing this module.
    include_flags = ["-I%s/include" % d for d in utils.get_petsc_dir()]
    include_flags.extend("-I%s" % d for d in self.local_kernel.include_dirs)
    include_flags.append("-I%s" % os.path.abspath(os.path.dirname(__file__)))

    # Linker flags: PETSc library search paths and rpaths, the PETSc and
    # maths libraries, then whatever the local kernel asks for.
    link_flags = ["-L%s/lib" % d for d in utils.get_petsc_dir()]
    link_flags.extend("-Wl,-rpath,%s/lib" % d for d in utils.get_petsc_dir())
    link_flags.extend(("-lpetsc", "-lm"))
    link_flags.extend(self.local_kernel.ldargs)

    return compilation.load(
        self,
        extension,
        self.name,
        cppargs=tuple(include_flags),
        ldargs=tuple(link_flags),
        restype=ctypes.c_int,
        comm=comm,
    )
9181
@utils.cached_property
def argtypes(self):
    """Return the ctypes argument types of the global kernel as a tuple.

    The first two arguments to the global kernel are the 'start' and 'stop'
    indices, declared with the ctypes equivalent of ``IntType``. All other
    arguments are declared to be void pointers.
    """
    index_type = as_ctypes(IntType)
    # One void pointer for every wrapper argument after 'start' and 'stop'.
    n_pointers = len(self.builder.wrapper_args[2:])
    # ``c_void_p`` is the documented name; ``c_voidp`` is only a legacy alias
    # for the same type.
    return (index_type, index_type) + (ctypes.c_void_p,) * n_pointers
11089
11190
11291class ParLoop (AbstractParLoop ):
@@ -128,12 +107,6 @@ def prepare_arglist(self, iterset, *args):
128107 seen .add (k )
129108 return arglist
130109
131- @utils .cached_property
132- def _jitmodule (self ):
133- return JITModule (self .kernel , self .iterset , * self .args ,
134- iterate = self .iteration_region ,
135- pass_layer_arg = self ._pass_layer_arg )
136-
137110 @mpi .collective
138111 def _compute (self , part , fun , * arglist ):
139112 with self ._compute_event :
0 commit comments