Skip to content

Commit ddbe4ec

Browse files
Allow packages to register include paths with xobjects
1 parent 9fd00a5 commit ddbe4ec

File tree

4 files changed

+34
-1
lines changed

4 files changed

+34
-1
lines changed

xobjects/context.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import logging
77
import os
88
import weakref
9-
import xobjects as xo
109
from abc import ABC, abstractmethod
1110
from collections import defaultdict
11+
from importlib.metadata import entry_points
1212
from pathlib import Path
1313
from typing import (
1414
Dict,
@@ -354,6 +354,26 @@ def build_kernels(
354354
) -> Dict[Tuple[str, tuple], KernelType]:
355355
pass
356356

357+
def get_installed_c_source_paths(self) -> List[str]:
358+
"""Returns a list of include paths registered in dependent packages.
359+
360+
In a package that depends on xobjects, you can register C source paths
361+
using the entry point `xobjects.c_sources`. A path to the directory
362+
containing the specified module will be added to the include path when
363+
building kernels. For example, the following will allow to write
364+
``#include <xtrack/path/to/some/header.h>`` in kernel sources:
365+
366+
.. code-block:: toml
367+
[project.entry-points.xobjects]
368+
include = "xtrack"
369+
"""
370+
sources = []
371+
for ep in entry_points(group='xobjects', name='include'):
372+
module = ep.load()
373+
path = Path(module.__file__).parents[1]
374+
sources.append(str(path))
375+
return sources
376+
357377
@abstractmethod
358378
def nparray_to_context_array(self, arr):
359379
return arr

xobjects/context_cpu.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,11 @@ def compile_kernel(
426426
xtr_compile_args.append("-DXO_CONTEXT_CPU_SERIAL")
427427
xtr_link_args.append("-DXO_CONTEXT_CPU_SERIAL")
428428

429+
extra_include_paths = self.get_installed_c_source_paths()
430+
include_flags = [f'-I{path}' for path in extra_include_paths]
431+
xtr_compile_args.extend(include_flags)
432+
xtr_link_args.extend(include_flags)
433+
429434
if os.name == "nt": # windows
430435
# TODO: to be handled properly
431436
xtr_compile_args = []

xobjects/context_cupy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,10 @@ def build_kernels(
456456
fid.write(specialized_source)
457457

458458
extra_compile_args = (*extra_compile_args, "-DXO_CONTEXT_CUDA")
459+
extra_include_paths = self.get_installed_c_source_paths()
460+
include_flags = [f'-I{path}' for path in extra_include_paths]
461+
xtr_compile_args.extend(include_flags)
462+
459463
module = cupy.RawModule(
460464
code=specialized_source, options=extra_compile_args
461465
)

xobjects/context_pyopencl.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,12 @@ def build_kernels(
219219
with open(save_source_as, "w") as fid:
220220
fid.write(specialized_source)
221221

222+
extra_include_paths = self.get_installed_c_source_paths()
223+
include_flags = [f'-I{path}' for path in extra_include_paths]
224+
222225
extra_compile_args = (
223226
*extra_compile_args,
227+
*include_flags,
224228
"-cl-std=CL2.0",
225229
"-DXO_CONTEXT_CL",
226230
)

0 commit comments

Comments
 (0)