Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[enhancement] Refactor onedal/datatypes in preparation for dlpack support #2195

Merged
merged 17 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions onedal/basic_statistics/basic_statistics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include "onedal/version.hpp"

#define NO_IMPORT_ARRAY // import_array called in table.cpp
#include "onedal/datatypes/data_conversion.hpp"
#include "onedal/datatypes/numpy/data_conversion.hpp"

#include <string>
#include <regex>
Expand Down Expand Up @@ -210,30 +210,30 @@ void init_partial_compute_result(py::module_& m) {
.def(py::pickle(
[](const result_t& res) {
return py::make_tuple(
py::cast<py::object>(convert_to_pyobject(res.get_partial_n_rows())),
py::cast<py::object>(convert_to_pyobject(res.get_partial_min())),
py::cast<py::object>(convert_to_pyobject(res.get_partial_max())),
py::cast<py::object>(convert_to_pyobject(res.get_partial_sum())),
py::cast<py::object>(convert_to_pyobject(res.get_partial_sum_squares())),
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_n_rows())),
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_min())),
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_max())),
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_sum())),
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_sum_squares())),
py::cast<py::object>(
convert_to_pyobject(res.get_partial_sum_squares_centered())));
numpy::convert_to_pyobject(res.get_partial_sum_squares_centered())));
},
[](py::tuple t) {
if (t.size() != 6)
throw std::runtime_error("Invalid state!");
result_t res;
if (py::cast<int>(t[0].attr("size")) != 0)
res.set_partial_n_rows(convert_to_table(t[0]));
res.set_partial_n_rows(numpy::convert_to_table(t[0]));
if (py::cast<int>(t[1].attr("size")) != 0)
res.set_partial_min(convert_to_table(t[1]));
res.set_partial_min(numpy::convert_to_table(t[1]));
if (py::cast<int>(t[2].attr("size")) != 0)
res.set_partial_max(convert_to_table(t[2]));
res.set_partial_max(numpy::convert_to_table(t[2]));
if (py::cast<int>(t[3].attr("size")) != 0)
res.set_partial_sum(convert_to_table(t[3]));
res.set_partial_sum(numpy::convert_to_table(t[3]));
if (py::cast<int>(t[4].attr("size")) != 0)
res.set_partial_sum_squares(convert_to_table(t[4]));
res.set_partial_sum_squares(numpy::convert_to_table(t[4]));
if (py::cast<int>(t[5].attr("size")) != 0)
res.set_partial_sum_squares_centered(convert_to_table(t[5]));
res.set_partial_sum_squares_centered(numpy::convert_to_table(t[5]));

return res;
}));
Expand Down
15 changes: 8 additions & 7 deletions onedal/covariance/covariance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "oneapi/dal/algo/covariance.hpp"

#define NO_IMPORT_ARRAY // import_array called in table.cpp
#include "onedal/datatypes/data_conversion.hpp"
#include "onedal/datatypes/numpy/data_conversion.hpp"

#include "onedal/common.hpp"
#include "onedal/version.hpp"
Expand Down Expand Up @@ -141,20 +141,21 @@ inline void init_partial_compute_result(pybind11::module_& m) {
.def(py::pickle(
[](const result_t& res) {
return py::make_tuple(
py::cast<py::object>(convert_to_pyobject(res.get_partial_n_rows())),
py::cast<py::object>(convert_to_pyobject(res.get_partial_crossproduct())),
py::cast<py::object>(convert_to_pyobject(res.get_partial_sum())));
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_n_rows())),
py::cast<py::object>(
numpy::convert_to_pyobject(res.get_partial_crossproduct())),
py::cast<py::object>(numpy::convert_to_pyobject(res.get_partial_sum())));
},
[](py::tuple t) {
if (t.size() != 3)
throw std::runtime_error("Invalid state!");
result_t res;
if (py::cast<int>(t[0].attr("size")) != 0)
res.set_partial_n_rows(convert_to_table(t[0]));
res.set_partial_n_rows(numpy::convert_to_table(t[0]));
if (py::cast<int>(t[1].attr("size")) != 0)
res.set_partial_crossproduct(convert_to_table(t[1]));
res.set_partial_crossproduct(numpy::convert_to_table(t[1]));
if (py::cast<int>(t[2].attr("size")) != 0)
res.set_partial_sum(convert_to_table(t[2]));
res.set_partial_sum(numpy::convert_to_table(t[2]));
return res;
}));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,27 @@ constexpr inline void apply(Op&& op, Args&&... args) {

#endif // Version check

// Dispatches a runtime `dal::data_type` value to a compile-time C type.
//
//   _T         - expression evaluating to a `dal::data_type`
//   _FUNCT     - function-like macro, expanded as `_FUNCT(CType);` with the
//                matching C type: float, double, std::int32_t or std::int64_t
//   _EXCEPTION - statement executed for any other data type (e.g. a `throw`)
//
// The expansion is one complete `switch` statement with no trailing `;`, so a
// call site may add its own `;` (or omit it) without producing an extra empty
// statement.
#define SET_CTYPE_FROM_DAL_TYPE(_T, _FUNCT, _EXCEPTION) \
    switch (_T) {                                       \
        case dal::data_type::float32: {                 \
            _FUNCT(float);                              \
            break;                                      \
        }                                               \
        case dal::data_type::float64: {                 \
            _FUNCT(double);                             \
            break;                                      \
        }                                               \
        case dal::data_type::int32: {                   \
            _FUNCT(std::int32_t);                       \
            break;                                      \
        }                                               \
        case dal::data_type::int64: {                   \
            _FUNCT(std::int64_t);                       \
            break;                                      \
        }                                               \
        default: _EXCEPTION;                            \
    }

namespace oneapi::dal::python {

using supported_types_t = std::tuple<float,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
#include "oneapi/dal/table/homogen.hpp"
#include "oneapi/dal/table/detail/homogen_utils.hpp"

#include "onedal/datatypes/data_conversion.hpp"
#include "onedal/datatypes/utils/numpy_helpers.hpp"
#include "onedal/datatypes/numpy/data_conversion.hpp"
#include "onedal/datatypes/numpy/numpy_utils.hpp"
#include "onedal/version.hpp"

#if ONEDAL_VERSION <= 20230100
Expand All @@ -32,7 +32,7 @@
#include "oneapi/dal/table/csr.hpp"
#endif

namespace oneapi::dal::python {
namespace oneapi::dal::python::numpy {

#if ONEDAL_VERSION <= 20230100
typedef oneapi::dal::detail::csr_table csr_table_t;
Expand Down Expand Up @@ -152,114 +152,114 @@
return res_table;
}

// Converts a Python object into a oneDAL table.
//
// Supported inputs:
//   * `None` (or a null handle)        -> empty, default-constructed table
//   * numpy array                      -> homogen table (non-contiguous input
//                                         is first copied to a C-contiguous array)
//   * scipy `csr_matrix` / `csr_array` -> csr table
//
// When built with ONEDAL_DATA_PARALLEL and `queue` refers to a device without
// the fp64 aspect, float64 input is cast to float32 (with a RuntimeWarning).
//
// Throws std::invalid_argument for unsupported input objects, unsupported
// element types, or malformed csr objects.
//
// All temporary Python references created here are held in `py::object`
// guards so that they are released on every exit path, including the throws.
dal::table convert_to_table(py::object inp_obj, py::object queue) {
    dal::table res;

    PyObject *obj = inp_obj.ptr();

    if (obj == nullptr || obj == Py_None) {
        return res;
    }

#ifdef ONEDAL_DATA_PARALLEL
    if (!queue.is(py::none()) && !queue.attr("sycl_device").attr("has_aspect_fp64").cast<bool>() &&
        hasattr(inp_obj, "dtype")) {
        // If the queue exists, doesn't have the fp64 aspect, and the data is float64
        // then cast it to float32
        int type = reinterpret_cast<PyArray_Descr *>(inp_obj.attr("dtype").ptr())->type_num;
        if (type == NPY_DOUBLE || type == NPY_DOUBLELTR) {
            PyErr_WarnEx(
                PyExc_RuntimeWarning,
                "Data will be converted into float32 from float64 because device does not support it",
                1);
            // use astype instead of PyArray_Cast in order to support scipy sparse inputs
            inp_obj = inp_obj.attr("astype")(py::dtype::of<float>());
            // queue will be set to none, as this check is no longer necessary
            return convert_to_table(inp_obj);
        }
    }
#endif // ONEDAL_DATA_PARALLEL

    if (is_array(obj)) {
        PyArrayObject *ary = reinterpret_cast<PyArrayObject *>(obj);

        if (!PyArray_ISCARRAY_RO(ary) && !PyArray_ISFARRAY_RO(ary)) {
            // NOTE: this will make a C-contiguous deep copy of the data
            // this is expected to be a special case
            // The copy is a new reference; the guard releases it even if the
            // recursive conversion throws.
            const py::object contiguous = py::reinterpret_steal<py::object>(
                reinterpret_cast<PyObject *>(PyArray_GETCONTIGUOUS(ary)));
            if (!contiguous) {
                throw std::invalid_argument(
                    "[convert_to_table] Numpy input could not be converted into onedal table.");
            }
            return convert_to_table(contiguous, queue);
        }
#define MAKE_HOMOGEN_TABLE(CType) res = convert_to_homogen_impl<CType>(ary);
        SET_NPY_FEATURE(array_type(ary),
                        array_type_sizeof(ary),
                        MAKE_HOMOGEN_TABLE,
                        throw std::invalid_argument("Found unsupported array type"));
#undef MAKE_HOMOGEN_TABLE
    }
    else if (strcmp(Py_TYPE(obj)->tp_name, "csr_matrix") == 0 ||
             strcmp(Py_TYPE(obj)->tp_name, "csr_array") == 0) {
        // PyObject_GetAttrString returns NEW references (or NULL on failure).
        // Owning them through py::object avoids leaking four references per call.
        const py::object data_ref =
            py::reinterpret_steal<py::object>(PyObject_GetAttrString(obj, "data"));
        const py::object column_indices_ref =
            py::reinterpret_steal<py::object>(PyObject_GetAttrString(obj, "indices"));
        const py::object row_indices_ref =
            py::reinterpret_steal<py::object>(PyObject_GetAttrString(obj, "indptr"));
        const py::object shape_ref =
            py::reinterpret_steal<py::object>(PyObject_GetAttrString(obj, "shape"));

        PyObject *py_data = data_ref.ptr();
        PyObject *py_column_indices = column_indices_ref.ptr();
        PyObject *py_row_indices = row_indices_ref.ptr();
        PyObject *py_shape = shape_ref.ptr();

        if (!(is_array(py_data) && is_array(py_column_indices) && is_array(py_row_indices) &&
              array_numdims(py_data) == 1 && array_numdims(py_column_indices) == 1 &&
              array_numdims(py_row_indices) == 1)) {
            throw std::invalid_argument("[convert_to_table] Got invalid csr_matrix object.");
        }

        // PyArray_FROMANY also returns new references. The index arrays are
        // forced copies cast to uint64 and are only needed during conversion.
        py::object np_data_ref = py::reinterpret_steal<py::object>(
            PyArray_FROMANY(py_data, array_type(py_data), 0, 0, NPY_ARRAY_CARRAY));
        const py::object np_column_indices_ref = py::reinterpret_steal<py::object>(
            PyArray_FROMANY(py_column_indices,
                            NPY_UINT64,
                            0,
                            0,
                            NPY_ARRAY_CARRAY | NPY_ARRAY_ENSURECOPY | NPY_ARRAY_FORCECAST));
        const py::object np_row_indices_ref = py::reinterpret_steal<py::object>(
            PyArray_FROMANY(py_row_indices,
                            NPY_UINT64,
                            0,
                            0,
                            NPY_ARRAY_CARRAY | NPY_ARRAY_ENSURECOPY | NPY_ARRAY_FORCECAST));

        // PyTuple_GetItem returns BORROWED references that live as long as the
        // shape tuple; a missing `shape` attribute is reported as an access failure
        // instead of being passed on as a null pointer.
        PyObject *np_row_count = py_shape ? PyTuple_GetItem(py_shape, 0) : nullptr;
        PyObject *np_column_count = py_shape ? PyTuple_GetItem(py_shape, 1) : nullptr;
        if (!(np_data_ref && np_column_indices_ref && np_row_indices_ref && np_row_count &&
              np_column_count)) {
            throw std::invalid_argument(
                "[convert_to_table] Failed accessing csr data when converting csr_matrix.\n");
        }

        const std::int64_t row_count = static_cast<std::int64_t>(PyLong_AsSsize_t(np_row_count));
        const std::int64_t column_count =
            static_cast<std::int64_t>(PyLong_AsSsize_t(np_column_count));

        PyObject *np_column_indices = np_column_indices_ref.ptr();
        PyObject *np_row_indices = np_row_indices_ref.ptr();
        // NOTE(review): the data array reference was never released after a
        // successful conversion in the previous code; presumably
        // convert_to_csr_impl keeps it alive for the zero-copy table - confirm.
        // The reference is therefore handed over unchanged from here on.
        PyObject *np_data = np_data_ref.release().ptr();

#define MAKE_CSR_TABLE(CType)                          \
    res = convert_to_csr_impl<CType>(np_data,          \
                                     np_column_indices, \
                                     np_row_indices,    \
                                     row_count,         \
                                     column_count);
        SET_NPY_FEATURE(array_type(np_data),
                        array_type_sizeof(np_data),
                        MAKE_CSR_TABLE,
                        throw std::invalid_argument("Found unsupported data type in csr_matrix"));
#undef MAKE_CSR_TABLE
        // np_column_indices / np_row_indices are released by their guards.
    }
    else {
        throw std::invalid_argument(
            "[convert_to_table] Not available input format for convert Python object to onedal table.");
    }

    return res;
}

Expand Down Expand Up @@ -432,4 +432,4 @@
return res;
}

} // namespace oneapi::dal::python
} // namespace oneapi::dal::python::numpy
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@

#include "oneapi/dal/table/common.hpp"

namespace oneapi::dal::python {
namespace oneapi::dal::python::numpy {

namespace py = pybind11;

// Converts a oneDAL table into a Python object, returned as a raw `PyObject *`.
// NOTE(review): ownership of the returned reference is not visible from this
// header - confirm against the implementation before wrapping the result.
PyObject *convert_to_pyobject(const dal::table &input);
// Converts a Python object into a oneDAL table.
//   inp_obj - `None` yields an empty table; numpy arrays and scipy
//             `csr_matrix` / `csr_array` objects are converted; anything else
//             raises std::invalid_argument.
//   queue   - optional queue object (defaults to `None`); in
//             ONEDAL_DATA_PARALLEL builds float64 input is cast to float32
//             when the queue's device lacks the fp64 aspect.
dal::table convert_to_table(py::object inp_obj, py::object queue = py::none());

} // namespace oneapi::dal::python
} // namespace oneapi::dal::python::numpy
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
* limitations under the License.
*******************************************************************************/

#include "onedal/datatypes/utils/numpy_helpers.hpp"
#include "onedal/datatypes/numpy/numpy_utils.hpp"

namespace oneapi::dal::python {
namespace oneapi::dal::python::numpy {

template <typename Key, typename Value>
auto reverse_map(const std::map<Key, Value>& input) {
Expand Down Expand Up @@ -50,4 +50,4 @@ npy_dtype_t convert_dal_to_npy_type(dal::data_type type) {
return get_dal_to_npy_map().at(type);
}

} // namespace oneapi::dal::python
} // namespace oneapi::dal::python::numpy
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@
// Raw data pointer of the NumPy array object `a`.
// The argument is parenthesized so that the cast applies to the whole
// expression passed in (e.g. `p + 1`), not just to its first operand.
#define array_data(a) PyArray_DATA((PyArrayObject *)(a))
// Length of dimension `i` of the NumPy array object `a`.
#define array_size(a, i) PyArray_DIM((PyArrayObject *)(a), (i))

namespace oneapi::dal::python {
namespace oneapi::dal::python::numpy {

using npy_dtype_t = decltype(NPY_FLOAT);
using npy_to_dal_t = std::map<npy_dtype_t, dal::data_type>;
Expand All @@ -152,4 +152,4 @@ const dal_to_npy_t &get_dal_to_npy_map();
dal::data_type convert_npy_to_dal_type(npy_dtype_t);
npy_dtype_t convert_dal_to_npy_type(dal::data_type);

} // namespace oneapi::dal::python
} // namespace oneapi::dal::python::numpy
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,11 @@
#include "oneapi/dal/table/detail/homogen_utils.hpp"

#include "onedal/common/sycl_interfaces.hpp"
#include "onedal/datatypes/data_conversion_sua_iface.hpp"
#include "onedal/datatypes/utils/dtype_conversions.hpp"
#include "onedal/datatypes/utils/dtype_dispatcher.hpp"
#include "onedal/datatypes/utils/sua_iface_helpers.hpp"
#include "onedal/datatypes/sycl_usm/data_conversion.hpp"
#include "onedal/datatypes/sycl_usm/dtype_conversion.hpp"
#include "onedal/datatypes/sycl_usm/sycl_usm_utils.hpp"

namespace oneapi::dal::python {
namespace oneapi::dal::python::sycl_usm {

using namespace pybind11::literals;
// Please follow <https://intelpython.github.io/dpctl/latest/
Expand Down Expand Up @@ -128,7 +127,7 @@ dal::table convert_to_homogen_impl(py::object obj) {
}

// Convert oneDAL table with zero-copy by use of `__sycl_usm_array_interface__` protocol.
dal::table convert_from_sua_iface(py::object obj) {
dal::table convert_to_table(py::object obj) {
// Get `__sycl_usm_array_interface__` dictionary representing USM allocations.
auto sua_iface_dict = get_sua_interface(obj);

Expand Down Expand Up @@ -236,6 +235,6 @@ void define_sycl_usm_array_property(py::class_<dal::table>& table_obj) {
table_obj.def_property_readonly("__sycl_usm_array_interface__", &construct_sua_iface);
}

} // namespace oneapi::dal::python
} // namespace oneapi::dal::python::sycl_usm

#endif // ONEDAL_DATA_PARALLEL
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@

#include "oneapi/dal/table/common.hpp"

namespace oneapi::dal::python {
namespace oneapi::dal::python::sycl_usm {

namespace py = pybind11;

// Convert oneDAL table with zero-copy by use of `__sycl_usm_array_interface__` protocol.
dal::table convert_from_sua_iface(py::object obj);
dal::table convert_to_table(py::object obj);

// Create a dictionary for `__sycl_usm_array_interface__` protocol from oneDAL table properties.
py::dict construct_sua_iface(const dal::table& input);
Expand All @@ -37,4 +37,4 @@ py::dict construct_sua_iface(const dal::table& input);
// USM allocations.
void define_sycl_usm_array_property(py::class_<dal::table>& t);

} // namespace oneapi::dal::python
} // namespace oneapi::dal::python::sycl_usm
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
#include "oneapi/dal/common.hpp"
#include "oneapi/dal/detail/common.hpp"

#include "onedal/datatypes/utils/dtype_conversions.hpp"
#include "onedal/datatypes/utils/dtype_dispatcher.hpp"
#include "onedal/datatypes/sycl_usm/dtype_conversion.hpp"
#include "onedal/datatypes/dtype_dispatcher.hpp"

namespace oneapi::dal::python {
namespace oneapi::dal::python::sycl_usm {

using fwd_map_t = std::unordered_map<std::string, dal::data_type>;
using inv_map_t = std::unordered_map<dal::data_type, std::string>;
Expand Down Expand Up @@ -139,4 +139,4 @@ std::string convert_dal_to_sua_type(dal::data_type dtype) {
return get_inv_map().at(dtype);
}

} // namespace oneapi::dal::python
} // namespace oneapi::dal::python::sycl_usm
33 changes: 33 additions & 0 deletions onedal/datatypes/sycl_usm/dtype_conversion.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*******************************************************************************
* Copyright 2024 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include <string>

#include <pybind11/pybind11.h>

#include "oneapi/dal/common.hpp"
#include "onedal/datatypes/dtype_dispatcher.hpp"

namespace py = pybind11;

// Data type conversions between oneDAL and the string type codes used by the
// `__sycl_usm_array_interface__` ("sua") protocol.
namespace oneapi::dal::python::sycl_usm {

// Maps a sua type string to the corresponding oneDAL data type.
// NOTE(review): behavior for unrecognized strings is defined in the .cpp and
// is not visible from this header - confirm before relying on it.
dal::data_type convert_sua_to_dal_type(std::string dtype);
// Maps a oneDAL data type to its sua type string.
// The definition returns `get_inv_map().at(dtype)`, so an unmapped data type
// presumably surfaces as std::out_of_range - confirm against the map type.
std::string convert_dal_to_sua_type(dal::data_type dtype);

} // namespace oneapi::dal::python::sycl_usm
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@
#include "oneapi/dal/table/detail/homogen_utils.hpp"

#include "onedal/common/sycl_interfaces.hpp"
#include "onedal/datatypes/data_conversion_sua_iface.hpp"
#include "onedal/datatypes/utils/dtype_conversions.hpp"
#include "onedal/datatypes/utils/dtype_dispatcher.hpp"
#include "onedal/datatypes/sycl_usm/data_conversion.hpp"
#include "onedal/datatypes/sycl_usm/dtype_conversion.hpp"

/* __sycl_usm_array_interface__
*
Expand All @@ -53,7 +52,7 @@
* api_reference/dpctl/sycl_usm_array_interface.html#sycl-usm-array-interface-attribute>
*/

namespace oneapi::dal::python {
namespace oneapi::dal::python::sycl_usm {

// Convert a string encoding elemental data type of the array to oneDAL homogen table data type.
dal::data_type get_sua_dtype(const py::dict& sua) {
Expand Down Expand Up @@ -197,6 +196,6 @@ py::tuple get_npy_strides(const dal::data_layout& data_layout,
return strides;
}

} // namespace oneapi::dal::python
} // namespace oneapi::dal::python::sycl_usm

#endif // ONEDAL_DATA_PARALLEL
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,10 @@
#include "oneapi/dal/table/detail/homogen_utils.hpp"

#include "onedal/common/sycl_interfaces.hpp"
#include "onedal/datatypes/data_conversion_sua_iface.hpp"
#include "onedal/datatypes/utils/dtype_conversions.hpp"
#include "onedal/datatypes/utils/dtype_dispatcher.hpp"
#include "onedal/datatypes/sycl_usm/data_conversion.hpp"
#include "onedal/datatypes/sycl_usm/dtype_conversion.hpp"

namespace oneapi::dal::python {
namespace oneapi::dal::python::sycl_usm {

dal::data_type get_sua_dtype(const py::dict& sua);

Expand Down Expand Up @@ -62,6 +61,6 @@ py::tuple get_npy_strides(const dal::data_layout& data_layout,
npy_intp row_count,
npy_intp column_count);

} // namespace oneapi::dal::python
} // namespace oneapi::dal::python::sycl_usm

#endif // ONEDAL_DATA_PARALLEL
16 changes: 8 additions & 8 deletions onedal/datatypes/table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
#include "oneapi/dal/table/homogen.hpp"

#ifdef ONEDAL_DATA_PARALLEL
#include "onedal/datatypes/data_conversion_sua_iface.hpp"
#include "onedal/datatypes/sycl_usm/data_conversion.hpp"
#endif // ONEDAL_DATA_PARALLEL

#include "onedal/datatypes/data_conversion.hpp"
#include "onedal/datatypes/utils/numpy_helpers.hpp"
#include "onedal/datatypes/numpy/data_conversion.hpp"
#include "onedal/datatypes/numpy/numpy_utils.hpp"
#include "onedal/common/pybind11_helpers.hpp"
#include "onedal/version.hpp"

Expand Down Expand Up @@ -74,25 +74,25 @@ ONEDAL_PY_INIT_MODULE(table) {
});
table_obj.def_property_readonly("dtype", [](const table& t) {
// returns a numpy dtype, even if source was not from numpy
return py::dtype(convert_dal_to_npy_type(t.get_metadata().get_data_type(0)));
return py::dtype(numpy::convert_dal_to_npy_type(t.get_metadata().get_data_type(0)));
});

#ifdef ONEDAL_DATA_PARALLEL
define_sycl_usm_array_property(table_obj);
sycl_usm::define_sycl_usm_array_property(table_obj);
#endif // ONEDAL_DATA_PARALLEL

m.def("to_table", [](py::object obj, py::object queue) {
#ifdef ONEDAL_DATA_PARALLEL
if (py::hasattr(obj, "__sycl_usm_array_interface__")) {
return convert_from_sua_iface(obj);
return sycl_usm::convert_to_table(obj);
}
#endif // ONEDAL_DATA_PARALLEL

return convert_to_table(obj, queue);
return numpy::convert_to_table(obj, queue);
});

m.def("from_table", [](const dal::table& t) -> py::handle {
auto* obj_ptr = convert_to_pyobject(t);
auto* obj_ptr = numpy::convert_to_pyobject(t);
return obj_ptr;
});
}
Expand Down
Loading
Loading