File tree Expand file tree Collapse file tree 4 files changed +34
-1
lines changed Expand file tree Collapse file tree 4 files changed +34
-1
lines changed Original file line number Diff line number Diff line change 66import logging
77import os
88import weakref
9- import xobjects as xo
109from abc import ABC , abstractmethod
1110from collections import defaultdict
11+ from importlib .metadata import entry_points
1212from pathlib import Path
1313from typing import (
1414 Dict ,
@@ -354,6 +354,26 @@ def build_kernels(
354354 ) -> Dict [Tuple [str , tuple ], KernelType ]:
355355 pass
356356
357+ def get_installed_c_source_paths (self ) -> List [str ]:
358+ """Returns a list of include paths registered in dependent packages.
359+
360+ In a package that depends on xobjects, you can register C source paths
361+ using the entry point `xobjects.c_sources`. A path to the directory
362+ containing the specified module will be added to the include path when
363+ building kernels. For example, the following will allow to write
364+ ``#include <xtrack/path/to/some/header.h>`` in kernel sources:
365+
366+ .. code-block:: toml
367+ [project.entry-points.xobjects]
368+ include = "xtrack"
369+ """
370+ sources = []
371+ for ep in entry_points (group = 'xobjects' , name = 'include' ):
372+ module = ep .load ()
373+ path = Path (module .__file__ ).parents [1 ]
374+ sources .append (str (path ))
375+ return sources
376+
357377 @abstractmethod
358378 def nparray_to_context_array (self , arr ):
359379 return arr
Original file line number Diff line number Diff line change @@ -426,6 +426,11 @@ def compile_kernel(
426426 xtr_compile_args .append ("-DXO_CONTEXT_CPU_SERIAL" )
427427 xtr_link_args .append ("-DXO_CONTEXT_CPU_SERIAL" )
428428
429+ extra_include_paths = self .get_installed_c_source_paths ()
430+ include_flags = [f'-I{ path } ' for path in extra_include_paths ]
431+ xtr_compile_args .extend (include_flags )
432+ xtr_link_args .extend (include_flags )
433+
429434 if os .name == "nt" : # windows
430435 # TODO: to be handled properly
431436 xtr_compile_args = []
Original file line number Diff line number Diff line change @@ -456,6 +456,10 @@ def build_kernels(
456456 fid .write (specialized_source )
457457
458458 extra_compile_args = (* extra_compile_args , "-DXO_CONTEXT_CUDA" )
459+ extra_include_paths = self .get_installed_c_source_paths ()
460+ include_flags = [f'-I{ path } ' for path in extra_include_paths ]
461+ xtr_compile_args .extend (include_flags )
462+
459463 module = cupy .RawModule (
460464 code = specialized_source , options = extra_compile_args
461465 )
Original file line number Diff line number Diff line change @@ -219,8 +219,12 @@ def build_kernels(
219219 with open (save_source_as , "w" ) as fid :
220220 fid .write (specialized_source )
221221
222+ extra_include_paths = self .get_installed_c_source_paths ()
223+ include_flags = [f'-I{ path } ' for path in extra_include_paths ]
224+
222225 extra_compile_args = (
223226 * extra_compile_args ,
227+ * include_flags ,
224228 "-cl-std=CL2.0" ,
225229 "-DXO_CONTEXT_CL" ,
226230 )
You can’t perform that action at this time.
0 commit comments