Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,17 @@ add_subdirectory(src/cpp_bindings)

pybind11_add_module(pyGinkgoBindings ${PYGINKGO_CPP_SOURCES})

pybind11_add_module(pyGinkgoExtensions ${EXTENSIONS_CPP_SOURCES})

target_include_directories(pyGinkgoBindings
PRIVATE ${Ginkgo_BINARY_DIR}/include)
target_include_directories(pyGinkgoExtensions
PRIVATE ${Ginkgo_BINARY_DIR}/include)
target_link_libraries(pyGinkgoBindings PRIVATE Ginkgo::ginkgo)
target_link_libraries(pyGinkgoExtensions PRIVATE Ginkgo::ginkgo)
target_link_libraries(pyGinkgoBindings PUBLIC nlohmann_json::nlohmann_json)
set_property(TARGET pyGinkgoBindings PROPERTY CXX_STANDARD 14)
set_property(TARGET pyGinkgoExtensions PROPERTY CXX_STANDARD 14)

# Disable false positive warnings from GCC (>= 12.4 and <
# 13)(https://gcc.gnu.org/bugzilla/show_bug.cgi?id=115824 and
Expand All @@ -94,6 +100,8 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU"
AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 13.0.0)
target_compile_options(pyGinkgoBindings PRIVATE -Wno-array-bounds
-Wno-stringop-overread)
target_compile_options(pyGinkgoExtensions PRIVATE -Wno-array-bounds
-Wno-stringop-overread)
message(
WARNING
"GCC >= 12.4.0 < 13 gives false positive warning, so we need to add \
Expand Down
4 changes: 4 additions & 0 deletions src/cpp_bindings/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,7 @@ set(PYGINKGO_CPP_SOURCES
${PROJECT_SOURCE_DIR}/src/cpp_bindings/solver/triangular.cpp
${PROJECT_SOURCE_DIR}/src/cpp_bindings/factorization/factorization.cpp
PARENT_SCOPE)

set(EXTENSIONS_CPP_SOURCES
${PROJECT_SOURCE_DIR}/src/cpp_bindings/pylinop.cpp
PARENT_SCOPE)
94 changes: 94 additions & 0 deletions src/cpp_bindings/pylinop.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// SPDX-FileCopyrightText: 2026 pyGinkgo authors
//
// SPDX-License-Identifier: MIT

#include <iostream>

#include <ginkgo/ginkgo.hpp>

#include <pybind11/pybind11.h>


namespace py = pybind11;


namespace gko {


class PyLinOp : public EnableLinOp<PyLinOp> {
public:
friend class EnableLinOp<PyLinOp>;
friend class EnablePolymorphicObject<PyLinOp, LinOp>;

public:
void apply_impl(const gko::LinOp *b, gko::LinOp *x) const override
{
GKO_NOT_IMPLEMENTED;
}

void apply_impl(const gko::LinOp *alpha, const gko::LinOp *b,
const gko::LinOp *beta, gko::LinOp *x) const override
{
GKO_NOT_IMPLEMENTED;
}

explicit PyLinOp(std::shared_ptr<const gko::Executor> exec,
gko::dim<2> dim = gko::dim<2>{})
: gko::EnableLinOp<PyLinOp>(std::move(exec), dim)
{}
};


class PyLinOpTrampoline : public PyLinOp {
using PyLinOp::PyLinOp;

public:
void apply_impl(const gko::LinOp *b, gko::LinOp *x) const override
{
PYBIND11_OVERRIDE(void, PyLinOp, apply_impl, b, x);
}
};


std::unique_ptr<PyLinOp> create(std::shared_ptr<const gko::Executor> exec)
{
return std::unique_ptr<PyLinOp>{new PyLinOpTrampoline{exec}};
}

std::unique_ptr<PyLinOp> create(std::shared_ptr<const gko::Executor> exec,
gko::dim<2> dim)
{
return std::unique_ptr<PyLinOp>{new PyLinOpTrampoline{exec, dim}};
}


} // namespace gko


class Publicist
: public gko::LinOp { // helper type for exposing protected functions
public:
using gko::LinOp::apply_impl; // inherited with different access modifier
};

PYBIND11_MODULE(pyGinkgoExtensions, m, py::mod_gil_not_used())
{
py::class_<gko::PyLinOp, gko::LinOp, gko::PyLinOpTrampoline,
std::shared_ptr<gko::PyLinOp>>(m, "PyLinOp")
.def(py::init([](std::shared_ptr<const gko::Executor> exec) {
auto A = std::shared_ptr(create(exec));
return A;
}))
.def(py::init(
[](std::shared_ptr<const gko::Executor> exec, py::tuple dim) {
auto A = std::shared_ptr(create(
exec,
gko::dim<2>{dim[0].cast<size_t>(), dim[1].cast<size_t>()}));
return A;
}))
.def("apply_impl", py::overload_cast<const gko::LinOp *, gko::LinOp *>(
&Publicist::apply_impl, py::const_));

m.def("call_apply", [](const gko::LinOp *linop, const gko::LinOp *b,
gko::LinOp *x) { linop->apply(b, x); });
}
1 change: 1 addition & 0 deletions tests/cpp_bindings/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pyginkgo_bindings_test(iterative_solver)
pyginkgo_bindings_test(direct_solver)
pyginkgo_bindings_test(torch)
pyginkgo_bindings_test(factorization)
pyginkgo_bindings_test(pylinop)

# add CUDA specific tests, if building with CUDA enabled
if(GINKGO_BUILD_CUDA)
Expand Down
Loading