diff --git a/CMakeLists.txt b/CMakeLists.txt index f4964209..f8ae7644 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -80,11 +80,17 @@ add_subdirectory(src/cpp_bindings) pybind11_add_module(pyGinkgoBindings ${PYGINKGO_CPP_SOURCES}) +pybind11_add_module(pyGinkgoExtensions ${EXTENSIONS_CPP_SOURCES}) + target_include_directories(pyGinkgoBindings PRIVATE ${Ginkgo_BINARY_DIR}/include) +target_include_directories(pyGinkgoExtensions + PRIVATE ${Ginkgo_BINARY_DIR}/include) target_link_libraries(pyGinkgoBindings PRIVATE Ginkgo::ginkgo) +target_link_libraries(pyGinkgoExtensions PRIVATE Ginkgo::ginkgo) target_link_libraries(pyGinkgoBindings PUBLIC nlohmann_json::nlohmann_json) set_property(TARGET pyGinkgoBindings PROPERTY CXX_STANDARD 14) +set_property(TARGET pyGinkgoExtensions PROPERTY CXX_STANDARD 14) # Disable false positive warnings from GCC (>= 12.4 and < # 13)(https://gcc.gnu.org/bugzilla/show_bug.cgi?id=115824 and @@ -94,6 +100,8 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 13.0.0) target_compile_options(pyGinkgoBindings PRIVATE -Wno-array-bounds -Wno-stringop-overread) + target_compile_options(pyGinkgoExtensions PRIVATE -Wno-array-bounds + -Wno-stringop-overread) message( WARNING "GCC >= 12.4.0 < 13 gives false positive warning, so we need to add \ diff --git a/src/cpp_bindings/CMakeLists.txt b/src/cpp_bindings/CMakeLists.txt index b506ac87..3194189b 100644 --- a/src/cpp_bindings/CMakeLists.txt +++ b/src/cpp_bindings/CMakeLists.txt @@ -19,3 +19,7 @@ set(PYGINKGO_CPP_SOURCES ${PROJECT_SOURCE_DIR}/src/cpp_bindings/solver/triangular.cpp ${PROJECT_SOURCE_DIR}/src/cpp_bindings/factorization/factorization.cpp PARENT_SCOPE) + +set(EXTENSIONS_CPP_SOURCES + ${PROJECT_SOURCE_DIR}/src/cpp_bindings/pylinop.cpp + PARENT_SCOPE) diff --git a/src/cpp_bindings/pylinop.cpp b/src/cpp_bindings/pylinop.cpp new file mode 100644 index 00000000..cca64ec6 --- /dev/null +++ b/src/cpp_bindings/pylinop.cpp @@ -0,0 +1,94 @@ +// SPDX-FileCopyrightText: 2026 pyGinkgo authors +// +// SPDX-License-Identifier: MIT + +#include + +#include + +#include + + +namespace py = pybind11; + + +namespace gko { + + +class PyLinOp : public EnableLinOp { +public: + friend class EnableLinOp; + friend class EnablePolymorphicObject; + +public: + void apply_impl(const gko::LinOp *b, gko::LinOp *x) const override + { + GKO_NOT_IMPLEMENTED; + } + + void apply_impl(const gko::LinOp *alpha, const gko::LinOp *b, + const gko::LinOp *beta, gko::LinOp *x) const override + { + GKO_NOT_IMPLEMENTED; + } + + explicit PyLinOp(std::shared_ptr exec, + gko::dim<2> dim = gko::dim<2>{}) + : gko::EnableLinOp(std::move(exec), dim) + {} +}; + + +class PyLinOpTrampoline : public PyLinOp { + using PyLinOp::PyLinOp; + +public: + void apply_impl(const gko::LinOp *b, gko::LinOp *x) const override + { + PYBIND11_OVERRIDE(void, PyLinOp, apply_impl, b, x); + } +}; + + +std::unique_ptr create(std::shared_ptr exec) +{ + return std::unique_ptr{new PyLinOpTrampoline{exec}}; +} + +std::unique_ptr create(std::shared_ptr exec, + gko::dim<2> dim) +{ + return std::unique_ptr{new PyLinOpTrampoline{exec, dim}}; +} + + +} // namespace gko + + +class Publicist + : public gko::LinOp { // helper type for exposing protected functions +public: + using gko::LinOp::apply_impl; // inherited with different access modifier +}; + +PYBIND11_MODULE(pyGinkgoExtensions, m, py::mod_gil_not_used()) +{ + py::class_>(m, "PyLinOp") + .def(py::init([](std::shared_ptr exec) { + auto A = std::shared_ptr(create(exec)); + return A; + })) + .def(py::init( + [](std::shared_ptr exec, py::tuple dim) { + auto A = std::shared_ptr(create( + exec, + gko::dim<2>{dim[0].cast(), dim[1].cast()})); + return A; + })) + .def("apply_impl", py::overload_cast( + &Publicist::apply_impl, py::const_)); + + m.def("call_apply", [](const gko::LinOp *linop, const gko::LinOp *b, + gko::LinOp *x) { linop->apply(b, x); }); +} diff --git a/tests/cpp_bindings/CMakeLists.txt b/tests/cpp_bindings/CMakeLists.txt index 61dbf654..3beb8bd9 100644 --- a/tests/cpp_bindings/CMakeLists.txt +++ b/tests/cpp_bindings/CMakeLists.txt @@ -25,6 +25,7 @@ pyginkgo_bindings_test(iterative_solver) pyginkgo_bindings_test(direct_solver) pyginkgo_bindings_test(torch) pyginkgo_bindings_test(factorization) +pyginkgo_bindings_test(pylinop) # add CUDA specific tests, if building with CUDA enabled if(GINKGO_BUILD_CUDA)