Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion xobjects/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import logging
import os
import weakref
import xobjects as xo
from abc import ABC, abstractmethod
from collections import defaultdict
from importlib.metadata import entry_points
from pathlib import Path
from typing import (
Dict,
Expand Down Expand Up @@ -354,6 +354,26 @@ def build_kernels(
) -> Dict[Tuple[str, tuple], KernelType]:
pass

def get_installed_c_source_paths(self) -> List[str]:
"""Returns a list of include paths registered in dependent packages.

In a package that depends on xobjects, you can register C source paths
using the entry point `xobjects.c_sources`. A path to the directory
containing the specified module will be added to the include path when
building kernels. For example, the following will allow to write
``#include <xtrack/path/to/some/header.h>`` in kernel sources:

.. code-block:: toml
[project.entry-points.xobjects]
include = "xtrack"
"""
sources = []
for ep in entry_points(group='xobjects', name='include'):
module = ep.load()
path = Path(module.__file__).parents[1]
sources.append(str(path))
return sources

@abstractmethod
def nparray_to_context_array(self, arr):
return arr
Expand Down
5 changes: 5 additions & 0 deletions xobjects/context_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,11 @@ def compile_kernel(
xtr_compile_args.append("-DXO_CONTEXT_CPU_SERIAL")
xtr_link_args.append("-DXO_CONTEXT_CPU_SERIAL")

extra_include_paths = self.get_installed_c_source_paths()
include_flags = [f'-I{path}' for path in extra_include_paths]
xtr_compile_args.extend(include_flags)
xtr_link_args.extend(include_flags)

if os.name == "nt": # windows
# TODO: to be handled properly
xtr_compile_args = []
Expand Down
4 changes: 4 additions & 0 deletions xobjects/context_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,10 @@ def build_kernels(
fid.write(specialized_source)

extra_compile_args = (*extra_compile_args, "-DXO_CONTEXT_CUDA")
extra_include_paths = self.get_installed_c_source_paths()
include_flags = [f'-I{path}' for path in extra_include_paths]
xtr_compile_args.extend(include_flags)

module = cupy.RawModule(
code=specialized_source, options=extra_compile_args
)
Expand Down
4 changes: 4 additions & 0 deletions xobjects/context_pyopencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,12 @@ def build_kernels(
with open(save_source_as, "w") as fid:
fid.write(specialized_source)

extra_include_paths = self.get_installed_c_source_paths()
include_flags = [f'-I{path}' for path in extra_include_paths]

extra_compile_args = (
*extra_compile_args,
*include_flags,
"-cl-std=CL2.0",
"-DXO_CONTEXT_CL",
)
Expand Down
Loading