Skip to content

Commit 74a33ad

Browse files
Add context-based define flags and propagate extra compile flags
1 parent 555de70 commit 74a33ad

File tree

6 files changed

+34
-9
lines changed

6 files changed

+34
-9
lines changed

xobjects/context.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def add_kernels(
254254
extra_classes: Sequence[Type] = (),
255255
extra_headers: Sequence[SourceType] = (),
256256
compile: bool = True, # noqa
257+
extra_compile_args: Sequence[str] = (),
257258
):
258259
"""
259260
Adds user-defined kernels to the context. The kernel source
@@ -333,6 +334,7 @@ def add_kernels(
333334
extra_classes=extra_classes,
334335
extra_headers=extra_headers,
335336
compile=compile,
337+
extra_compile_args=extra_compile_args,
336338
)
337339
self.kernels.update(generated_kernels)
338340

@@ -348,6 +350,7 @@ def build_kernels(
348350
extra_classes: Sequence[Type],
349351
extra_headers: Sequence[SourceType],
350352
compile: bool,
353+
extra_compile_args: Sequence[str],
351354
) -> Dict[Tuple[str, tuple], KernelType]:
352355
pass
353356

xobjects/context_cpu.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ def add_kernels(
169169
specialize=True,
170170
apply_to_source=(),
171171
save_source_as=None,
172-
extra_compile_args: Sequence[str] = ("-O3", "-Wno-unused-function"),
173-
extra_link_args: Sequence[str] = ("-O3",),
172+
extra_compile_args: Sequence[str] = (),
173+
extra_link_args: Sequence[str] = (),
174174
extra_cdef="",
175175
extra_classes=(),
176176
extra_headers=(),
@@ -271,13 +271,16 @@ def build_kernels(
271271
specialize=True,
272272
apply_to_source=(),
273273
save_source_as=None,
274-
extra_compile_args=("-O3", "-Wno-unused-function"),
275-
extra_link_args=("-O3",),
274+
extra_compile_args=(),
275+
extra_link_args=(),
276276
extra_cdef="",
277277
extra_classes=(),
278278
extra_headers=(),
279279
compile=True, # noqa
280280
) -> Dict[Tuple[str, tuple], "KernelCpu"]:
281+
extra_compile_args += ("-O3", "-Wno-unused-function")
282+
extra_link_args += ("-O3",)
283+
281284
# Determine names and paths
282285
clean_up_so = not module_name
283286
module_name = module_name or str(uuid.uuid4().hex)
@@ -409,20 +412,25 @@ def compile_kernel(
409412
ffi_interface.cdef("int omp_get_max_threads();")
410413

411414
# Compile
412-
xtr_compile_args = ["-std=c99"]
413-
xtr_link_args = ["-std=c99"]
415+
xtr_compile_args = ["-std=c99", "-DXO_CONTEXT_CPU"]
416+
xtr_link_args = ["-std=c99", "-DXO_CONTEXT_CPU"]
414417
xtr_compile_args += extra_compile_args
415418
xtr_link_args += extra_link_args
416419

417420
if self.openmp_enabled:
418421
xtr_compile_args.append("-fopenmp")
419422
xtr_link_args.append("-fopenmp")
423+
xtr_compile_args.append("-DXO_CONTEXT_CPU_OPENMP")
424+
xtr_link_args.append("-DXO_CONTEXT_CPU_OPENMP")
420425

421426
# https://mac.r-project.org/openmp/
422427
# On macOS, comment out the flags above and uncomment the flags below to compile OpenMP with Xcode clang:
423428
# xtr_compile_args.append("-Xclang")
424429
# xtr_compile_args.append("-fopenmp")
425430
# xtr_link_args.append("-lomp")
431+
else:
432+
xtr_compile_args.append("-DXO_CONTEXT_CPU_SERIAL")
433+
xtr_link_args.append("-DXO_CONTEXT_CPU_SERIAL")
426434

427435
if os.name == "nt": # windows
428436
# TODO: to be handled properly

xobjects/context_cupy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ def build_kernels(
416416
specialize=True,
417417
apply_to_source=(),
418418
save_source_as=None,
419+
extra_compile_args=(),
419420
extra_cdef=None,
420421
extra_classes=(),
421422
extra_headers=(),
@@ -454,7 +455,10 @@ def build_kernels(
454455
with open(save_source_as, "w") as fid:
455456
fid.write(specialized_source)
456457

457-
module = cupy.RawModule(code=specialized_source)
458+
extra_compile_args = (*extra_compile_args, "-DXO_CONTEXT_CUDA")
459+
module = cupy.RawModule(
460+
code=specialized_source, options=extra_compile_args
461+
)
458462

459463
out_kernels = {}
460464
for pyname, kernel in kernel_descriptions.items():

xobjects/context_pyopencl.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def build_kernels(
180180
specialize=True,
181181
apply_to_source=(),
182182
save_source_as=None,
183+
extra_compile_args=(),
183184
extra_cdef=None,
184185
extra_classes=(),
185186
extra_headers=(),
@@ -218,8 +219,13 @@ def build_kernels(
218219
with open(save_source_as, "w") as fid:
219220
fid.write(specialized_source)
220221

222+
extra_compile_args = (
223+
*extra_compile_args,
224+
"-cl-std=CL2.0",
225+
"-DXO_CONTEXT_CL",
226+
)
221227
prg = cl.Program(self.context, specialized_source).build(
222-
options="-cl-std=CL2.0",
228+
options=extra_compile_args,
223229
)
224230

225231
out_kernels = {}

xobjects/hybrid_class.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def default(self, obj):
9696
def _build_xofields_dict(bases, data):
9797
if "_xofields" in data.keys():
9898
xofields = data["_xofields"].copy()
99-
elif any(map(lambda b: hasattr(b, "_xofields"), bases)):
99+
elif any(hasattr(b, "_xofields") for b in bases):
100100
n_filled = 0
101101
for bb in bases:
102102
if hasattr(bb, "_xofields") and len(bb._xofields.keys()) > 0:

xobjects/struct.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,7 @@ def compile_class_kernels(
473473
apply_to_source=(),
474474
save_source_as=None,
475475
extra_classes=(),
476+
extra_compile_args=(),
476477
):
477478
if only_if_needed:
478479
all_found = True
@@ -489,6 +490,7 @@ def compile_class_kernels(
489490
extra_classes=[cls] + list(extra_classes),
490491
apply_to_source=apply_to_source,
491492
save_source_as=save_source_as,
493+
extra_compile_args=extra_compile_args,
492494
)
493495

494496
def compile_kernels(
@@ -497,13 +499,15 @@ def compile_kernels(
497499
apply_to_source=(),
498500
save_source_as=None,
499501
extra_classes=(),
502+
extra_compile_args=(),
500503
):
501504
self.compile_class_kernels(
502505
context=self._context,
503506
only_if_needed=only_if_needed,
504507
apply_to_source=apply_to_source,
505508
save_source_as=save_source_as,
506509
extra_classes=extra_classes,
510+
extra_compile_args=extra_compile_args,
507511
)
508512

509513

0 commit comments

Comments
 (0)