diff --git a/xobjects/context.py b/xobjects/context.py index a2c71e0..2d27a04 100644 --- a/xobjects/context.py +++ b/xobjects/context.py @@ -6,9 +6,9 @@ import logging import os import weakref -import xobjects as xo from abc import ABC, abstractmethod from collections import defaultdict +from importlib.metadata import entry_points from pathlib import Path from typing import ( Dict, @@ -354,6 +354,26 @@ def build_kernels( ) -> Dict[Tuple[str, tuple], KernelType]: pass + def get_installed_c_source_paths(self) -> List[str]: + """Returns a list of include paths registered in dependent packages. + + In a package that depends on xobjects, you can register C source paths + using the entry point `xobjects.c_sources`. A path to the directory + containing the specified module will be added to the include path when + building kernels. For example, the following will allow to write + ``#include `` in kernel sources: + + .. code-block:: toml + [project.entry-points.xobjects] + include = "xtrack" + """ + sources = [] + for ep in entry_points(group='xobjects', name='include'): + module = ep.load() + path = Path(module.__file__).parents[1] + sources.append(str(path)) + return sources + @abstractmethod def nparray_to_context_array(self, arr): return arr diff --git a/xobjects/context_cpu.py b/xobjects/context_cpu.py index 6234a9a..33f45f4 100644 --- a/xobjects/context_cpu.py +++ b/xobjects/context_cpu.py @@ -426,6 +426,11 @@ def compile_kernel( xtr_compile_args.append("-DXO_CONTEXT_CPU_SERIAL") xtr_link_args.append("-DXO_CONTEXT_CPU_SERIAL") + extra_include_paths = self.get_installed_c_source_paths() + include_flags = [f'-I{path}' for path in extra_include_paths] + xtr_compile_args.extend(include_flags) + xtr_link_args.extend(include_flags) + if os.name == "nt": # windows # TODO: to be handled properly xtr_compile_args = [] diff --git a/xobjects/context_cupy.py b/xobjects/context_cupy.py index fc16e2b..2d51b91 100644 --- a/xobjects/context_cupy.py +++ b/xobjects/context_cupy.py @@ -456,6 +456,10 @@ def build_kernels( fid.write(specialized_source) extra_compile_args = (*extra_compile_args, "-DXO_CONTEXT_CUDA") + extra_include_paths = self.get_installed_c_source_paths() + include_flags = [f'-I{path}' for path in extra_include_paths] + xtr_compile_args.extend(include_flags) + module = cupy.RawModule( code=specialized_source, options=extra_compile_args ) diff --git a/xobjects/context_pyopencl.py b/xobjects/context_pyopencl.py index 4424891..f9893d8 100644 --- a/xobjects/context_pyopencl.py +++ b/xobjects/context_pyopencl.py @@ -219,8 +219,12 @@ def build_kernels( with open(save_source_as, "w") as fid: fid.write(specialized_source) + extra_include_paths = self.get_installed_c_source_paths() + include_flags = [f'-I{path}' for path in extra_include_paths] + extra_compile_args = ( *extra_compile_args, + *include_flags, "-cl-std=CL2.0", "-DXO_CONTEXT_CL", )