Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions _pyceres/core/callbacks.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
namespace py = pybind11;

// Trampoline class so we can create an EvaluationCallback in Python.
class PyEvaluationCallBack : public ceres::EvaluationCallback {
class PyEvaluationCallBack : public ceres::EvaluationCallback,
py::trampoline_self_life_support {
public:
/* Inherit the constructors */
using ceres::EvaluationCallback::EvaluationCallback;
Expand All @@ -25,7 +26,8 @@ class PyEvaluationCallBack : public ceres::EvaluationCallback {
}
};

class PyIterationCallback : public ceres::IterationCallback {
class PyIterationCallback : public ceres::IterationCallback,
py::trampoline_self_life_support {
public:
using ceres::IterationCallback::IterationCallback;

Expand All @@ -43,7 +45,7 @@ class PyIterationCallback : public ceres::IterationCallback {
PYBIND11_MAKE_OPAQUE(std::vector<ceres::IterationCallback*>);

void BindCallbacks(py::module& m) {
py::class_<ceres::EvaluationCallback,
py::classh<ceres::EvaluationCallback,
PyEvaluationCallBack /* <--- trampoline*/>(m, "EvaluationCallback")
.def(py::init<>());
}
5 changes: 3 additions & 2 deletions _pyceres/core/cost_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ namespace py = pybind11;

// Class which we can use to create a ceres::CostFunction in python.
// This allows use to create python based cost functions.
class PyCostFunction : public ceres::CostFunction {
class PyCostFunction : public ceres::CostFunction,
py::trampoline_self_life_support {
public:
// Inherit the constructors.
using ceres::CostFunction::CostFunction;
Expand Down Expand Up @@ -91,7 +92,7 @@ class PyCostFunction : public ceres::CostFunction {
};

void BindCostFunctions(py::module& m) {
py::class_<ceres::CostFunction, PyCostFunction /* <--- trampoline*/>(
py::classh<ceres::CostFunction, PyCostFunction /* <--- trampoline*/>(
m, "CostFunction")
.def(py::init<>())
.def("num_residuals", &ceres::CostFunction::num_residuals)
Expand Down
4 changes: 2 additions & 2 deletions _pyceres/core/covariance.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace py = pybind11;

void BindCovariance(py::module& m) {
using Options = ceres::Covariance::Options;
py::class_<Options> PyOptions(m, "CovarianceOptions");
py::classh<Options> PyOptions(m, "CovarianceOptions");
PyOptions.def(py::init<>())
.def_property(
"num_threads",
Expand All @@ -28,7 +28,7 @@ void BindCovariance(py::module& m) {
.def_readwrite("apply_loss_function", &Options::apply_loss_function);
MakeDataclass(PyOptions);

py::class_<ceres::Covariance>(m, "Covariance")
py::classh<ceres::Covariance>(m, "Covariance")
.def(py::init<ceres::Covariance::Options>())
.def(
"compute",
Expand Down
2 changes: 1 addition & 1 deletion _pyceres/core/crs_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ py::tuple ConvertCRSToPyTuple(const ceres::CRSMatrix& crsMatrix) {

void BindCRSMatrix(py::module& m) {
using CRSMatrix = ceres::CRSMatrix;
py::class_<CRSMatrix> PyCRSMatrix(m, "CRSMatrix");
py::classh<CRSMatrix> PyCRSMatrix(m, "CRSMatrix");
PyCRSMatrix.def(py::init<>())
.def_readonly("num_rows", &CRSMatrix::num_rows)
.def_readonly("num_cols", &CRSMatrix::num_cols)
Expand Down
24 changes: 8 additions & 16 deletions _pyceres/core/loss_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
namespace py = pybind11;

// Trampoline class so that we can create a LossFunction in Python.
class PyLossFunction : public ceres::LossFunction {
class PyLossFunction : public ceres::LossFunction,
py::trampoline_self_life_support {
public:
/* Inherit the constructors */
using ceres::LossFunction::LossFunction;
Expand Down Expand Up @@ -86,9 +87,8 @@ std::shared_ptr<ceres::LossFunction> CreateLossFunctionFromDict(py::dict dict) {
}

void BindLossFunctions(py::module& m) {
py::class_<ceres::LossFunction,
PyLossFunction /*<--- trampoline*/,
std::shared_ptr<ceres::LossFunction>>(m, "LossFunction")
py::classh<ceres::LossFunction, PyLossFunction /*<--- trampoline*/>(
m, "LossFunction")
.def(py::init<>())
.def(py::init(&CreateLossFunctionFromDict))
.def("evaluate", [](ceres::LossFunction& self, float v) {
Expand All @@ -98,23 +98,15 @@ void BindLossFunctions(py::module& m) {
});
py::implicitly_convertible<py::dict, ceres::LossFunction>();

py::class_<ceres::TrivialLoss,
ceres::LossFunction,
std::shared_ptr<ceres::TrivialLoss>>(m, "TrivialLoss")
py::classh<ceres::TrivialLoss, ceres::LossFunction>(m, "TrivialLoss")
.def(py::init<>());

py::class_<ceres::HuberLoss,
ceres::LossFunction,
std::shared_ptr<ceres::HuberLoss>>(m, "HuberLoss")
py::classh<ceres::HuberLoss, ceres::LossFunction>(m, "HuberLoss")
.def(py::init<double>());

py::class_<ceres::SoftLOneLoss,
ceres::LossFunction,
std::shared_ptr<ceres::SoftLOneLoss>>(m, "SoftLOneLoss")
py::classh<ceres::SoftLOneLoss, ceres::LossFunction>(m, "SoftLOneLoss")
.def(py::init<double>());

py::class_<ceres::CauchyLoss,
ceres::LossFunction,
std::shared_ptr<ceres::CauchyLoss>>(m, "CauchyLoss")
py::classh<ceres::CauchyLoss, ceres::LossFunction>(m, "CauchyLoss")
.def(py::init<double>());
}
14 changes: 7 additions & 7 deletions _pyceres/core/manifold.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace py = pybind11;

class PyManifold : public ceres::Manifold {
class PyManifold : public ceres::Manifold, py::trampoline_self_life_support {
/* Inherit the constructors */
using ceres::Manifold::Manifold;
bool Plus(const double* x,
Expand Down Expand Up @@ -57,23 +57,23 @@ class PyManifold : public ceres::Manifold {
};

void BindManifold(py::module& m) {
py::class_<ceres::Manifold, PyManifold /* <--- trampoline*/>(m, "Manifold")
py::classh<ceres::Manifold, PyManifold /* <--- trampoline*/>(m, "Manifold")
.def(py::init<>())
.def("ambient_size", &ceres::Manifold::AmbientSize)
.def("tangent_size", &ceres::Manifold::TangentSize);

py::class_<ceres::EuclideanManifold<ceres::DYNAMIC>, ceres::Manifold>(
py::classh<ceres::EuclideanManifold<ceres::DYNAMIC>, ceres::Manifold>(
m, "EuclideanManifold")
.def(py::init<int>());
py::class_<ceres::SubsetManifold, ceres::Manifold>(m, "SubsetManifold")
py::classh<ceres::SubsetManifold, ceres::Manifold>(m, "SubsetManifold")
.def(py::init<int, const std::vector<int>&>());
py::class_<ceres::QuaternionManifold, ceres::Manifold>(m,
py::classh<ceres::QuaternionManifold, ceres::Manifold>(m,
"QuaternionManifold")
.def(py::init<>());
py::class_<ceres::EigenQuaternionManifold, ceres::Manifold>(
py::classh<ceres::EigenQuaternionManifold, ceres::Manifold>(
m, "EigenQuaternionManifold")
.def(py::init<>());
py::class_<ceres::SphereManifold<ceres::DYNAMIC>, ceres::Manifold>(
py::classh<ceres::SphereManifold<ceres::DYNAMIC>, ceres::Manifold>(
m, "SphereManifold")
.def(py::init<int>());
}
8 changes: 4 additions & 4 deletions _pyceres/core/problem.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ std::unique_ptr<ceres::Problem> CreatePythonProblem() {

void BindProblem(py::module& m) {
using options = ceres::Problem::Options;
py::class_<ceres::Problem::Options>(m, "ProblemOptions")
py::classh<ceres::Problem::Options>(m, "ProblemOptions")
.def(py::init(&CreateProblemOptions)) // Ensures default is that
// Python manages memory
.def_readonly("cost_function_ownership",
Expand All @@ -56,7 +56,7 @@ void BindProblem(py::module& m) {
.def_readwrite("disable_all_safety_checks",
&options::disable_all_safety_checks);

py::class_<ceres::Problem::EvaluateOptions>(m, "EvaluateOptions")
py::classh<ceres::Problem::EvaluateOptions>(m, "EvaluateOptions")
.def(py::init<>())
.def(
"set_parameter_blocks",
Expand All @@ -78,9 +78,9 @@ void BindProblem(py::module& m) {
.def_readwrite("num_threads",
&ceres::Problem::EvaluateOptions::num_threads);

py::class_<ResidualBlockIDWrapper> residual_block_wrapper(m, "ResidualBlock");
py::classh<ResidualBlockIDWrapper> residual_block_wrapper(m, "ResidualBlock");

py::class_<ceres::Problem, std::shared_ptr<ceres::Problem>>(m, "Problem")
py::classh<ceres::Problem>(m, "Problem")
.def(py::init(&CreatePythonProblem))
.def(py::init<ceres::Problem::Options>())
.def("num_parameter_blocks", &ceres::Problem::NumParameterBlocks)
Expand Down
9 changes: 4 additions & 5 deletions _pyceres/core/solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ namespace py = pybind11;

void BindSolver(py::module& m) {
using IterSummary = ceres::IterationSummary;
py::class_<IterSummary, std::shared_ptr<IterSummary>> PyIterSummary(
m, "IterationSummary");
py::classh<IterSummary> PyIterSummary(m, "IterationSummary");
PyIterSummary.def(py::init<>())
.def(py::init<const IterSummary&>())
.def_readonly("iteration", &IterSummary::iteration)
Expand Down Expand Up @@ -43,7 +42,7 @@ void BindSolver(py::module& m) {
.def_readonly("cumulative_time_in_seconds",
&IterSummary::cumulative_time_in_seconds);

py::class_<ceres::IterationCallback,
py::classh<ceres::IterationCallback,
PyIterationCallback /* <--- trampoline*/>(m, "IterationCallback")
.def(py::init<>())
.def("__call__", &ceres::IterationCallback::operator());
Expand All @@ -64,7 +63,7 @@ void BindSolver(py::module& m) {
std::vector<ceres::IterationCallback*>>();

using Options = ceres::Solver::Options;
py::class_<Options, std::shared_ptr<Options>> PyOptions(m, "SolverOptions");
py::classh<Options> PyOptions(m, "SolverOptions");
PyOptions.def(py::init<>())
.def(py::init<const Options&>())
.def("IsValid", &Options::IsValid)
Expand Down Expand Up @@ -180,7 +179,7 @@ void BindSolver(py::module& m) {
MakeDataclass(PyOptions);

using Summary = ceres::Solver::Summary;
py::class_<Summary, std::shared_ptr<Summary>> PySummary(m, "SolverSummary");
py::classh<Summary> PySummary(m, "SolverSummary");
PySummary.def(py::init<>())
.def(py::init<const Summary&>())
.def("BriefReport", &Summary::BriefReport)
Expand Down
4 changes: 2 additions & 2 deletions _pyceres/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ std::string CreateSummary(const T& self, bool write_type) {
}

template <typename T, typename... options>
void AddDefaultsToDocstrings(py::class_<T, options...> cls) {
void AddDefaultsToDocstrings(py::classh<T, options...> cls) {
auto obj = cls();
for (auto& handle : obj.attr("__dir__")()) {
const std::string attribute = py::str(handle);
Expand All @@ -233,7 +233,7 @@ void AddDefaultsToDocstrings(py::class_<T, options...> cls) {
}

template <typename T, typename... options>
void MakeDataclass(py::class_<T, options...> cls,
void MakeDataclass(py::classh<T, options...> cls,
const std::vector<std::string>& attributes = {}) {
AddDefaultsToDocstrings(cls);
if (!py::hasattr(cls, "summary")) {
Expand Down
2 changes: 1 addition & 1 deletion _pyceres/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ std::pair<std::string, int> GetPythonCallFrame() {
}

void BindLogging(py::module& m) {
py::class_<Logging> PyLogging(m, "logging", py::module_local());
py::classh<Logging> PyLogging(m, "logging", py::module_local());

py::enum_<Logging::LogSeverity>(PyLogging, "Level", py::module_local())
.value("INFO", Logging::LogSeverity::GLOG_INFO)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["scikit-build-core>=0.3.3", "pybind11==2.13.6"]
requires = ["scikit-build-core>=0.3.3", "pybind11==3.0.0"]
build-backend = "scikit_build_core.build"

[project]
Expand Down
Loading