diff --git a/cpp/daal/BUILD b/cpp/daal/BUILD
index 0cbf8a50316..1b4cbe4e9ec 100644
--- a/cpp/daal/BUILD
+++ b/cpp/daal/BUILD
@@ -223,6 +223,7 @@ daal_algorithms(
"stump",
"svd",
"svm",
+ "spectral_embedding",
"weak_learner/inner",
],
)
diff --git a/cpp/daal/src/algorithms/spectral_embedding/BUILD b/cpp/daal/src/algorithms/spectral_embedding/BUILD
new file mode 100644
index 00000000000..5fdba550542
--- /dev/null
+++ b/cpp/daal/src/algorithms/spectral_embedding/BUILD
@@ -0,0 +1,11 @@
+package(default_visibility = ["//visibility:public"])
+load("@onedal//dev/bazel:daal.bzl", "daal_module")
+
+daal_module(
+ name = "kernel",
+ auto = True,
+ deps = [
+ "@onedal//cpp/daal:core",
+ "@onedal//cpp/daal/src/algorithms/cosdistance:kernel",
+ ],
+)
diff --git a/cpp/daal/src/algorithms/spectral_embedding/spectral_embedding_default_dense_fpt_cpu.cpp b/cpp/daal/src/algorithms/spectral_embedding/spectral_embedding_default_dense_fpt_cpu.cpp
new file mode 100644
index 00000000000..684de1f8657
--- /dev/null
+++ b/cpp/daal/src/algorithms/spectral_embedding/spectral_embedding_default_dense_fpt_cpu.cpp
@@ -0,0 +1,39 @@
+/* file: spectral_embedding_default_dense_fpt_cpu.cpp */
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+//++
+// Instantiation of CPU-specific spectral_embedding kernel implementations
+//--
+*/
+
+#include "spectral_embedding_kernel.h"
+#include "spectral_embedding_default_dense_impl.i"
+
+namespace daal
+{
+namespace algorithms
+{
+namespace spectral_embedding
+{
+namespace internal
+{
+template class DAAL_EXPORT SpectralEmbeddingKernel;
+} // namespace internal
+} // namespace spectral_embedding
+} // namespace algorithms
+} // namespace daal
diff --git a/cpp/daal/src/algorithms/spectral_embedding/spectral_embedding_default_dense_impl.i b/cpp/daal/src/algorithms/spectral_embedding/spectral_embedding_default_dense_impl.i
new file mode 100644
index 00000000000..5f3d7981c6a
--- /dev/null
+++ b/cpp/daal/src/algorithms/spectral_embedding/spectral_embedding_default_dense_impl.i
@@ -0,0 +1,214 @@
+/* file: spectral_embedding_default_dense_impl.i */
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+//++
+// Implementation of cosine distance.
+//--
+*/
+
+#include "services/daal_defines.h"
+#include "src/externals/service_math.h"
+#include "src/externals/service_blas.h"
+#include "src/threading/threading.h"
+#include "src/algorithms/service_error_handling.h"
+#include "src/data_management/service_numeric_table.h"
+#include "src/algorithms/cosdistance/cosdistance_kernel.h"
+#include "src/externals/service_lapack.h"
+#include
+
+using namespace daal::internal;
+
+namespace daal
+{
+namespace algorithms
+{
+namespace spectral_embedding
+{
+namespace internal
+{
+
+template
+services::Status computeEigenvectorsInplace(size_t nFeatures, algorithmFPType * eigenvectors, algorithmFPType * eigenvalues)
+{
+ char jobz = 'V';
+ char uplo = 'U';
+
+ DAAL_INT lwork = 2 * nFeatures * nFeatures + 6 * nFeatures + 1;
+ DAAL_INT liwork = 5 * nFeatures + 3;
+ DAAL_INT info;
+
+ TArray work(lwork);
+ TArray iwork(liwork);
+ DAAL_CHECK_MALLOC(work.get() && iwork.get());
+
+ LapackInst::xsyevd(&jobz, &uplo, (DAAL_INT *)(&nFeatures), eigenvectors, (DAAL_INT *)(&nFeatures), eigenvalues, work.get(),
+ &lwork, iwork.get(), &liwork, &info);
+ if (info != 0) return services::Status(services::ErrorPCAFailedToComputeCorrelationEigenvalues); // CHANGE ERROR STATUS
+ return services::Status();
+}
+
+/**
+ * \brief Kernel for Spectral Embedding calculation
+ */
+template
+services::Status SpectralEmbeddingKernel::compute(const NumericTable * xTable, NumericTable * embeddingTable,
+ NumericTable * eigenTable, const KernelParameter & par)
+{
+ services::Status status;
+ // std::cout << "inside DAAL kernel" << std::endl;
+ // std::cout << "Params: " << par.numberOfEmbeddings << " " << par.numberOfNeighbors << std::endl;
+ size_t k = par.numberOfEmbeddings;
+ size_t filtNum = par.numberOfNeighbors + 1;
+ size_t n = xTable->getNumberOfRows(); /* Number of input feature vectors */
+
+ SharedPtr > tmpMatrixPtr =
+ HomogenNumericTable::create(n, n, NumericTable::doAllocate, &status);
+
+ DAAL_CHECK_STATUS_VAR(status);
+ NumericTable * covOutput = tmpMatrixPtr.get();
+ NumericTable * a0 = const_cast(xTable);
+ NumericTable * eigenvalues = const_cast(eigenTable);
+
+ // Compute cosine distances matrix
+ {
+ auto cosDistanceKernel = cosine_distance::internal::DistanceKernel();
+ DAAL_CHECK_STATUS(status, cosDistanceKernel.compute(0, &a0, 0, &covOutput, nullptr));
+ }
+
+ WriteRows xMatrix(covOutput, 0, n);
+ DAAL_CHECK_BLOCK_STATUS(xMatrix);
+ algorithmFPType * x = xMatrix.get();
+
+ size_t lcnt, rcnt, cnt;
+ algorithmFPType L, R, M;
+ // Use binary search to find such d that the number of verticies having distance <= d is filtNum
+ const size_t binarySearchIterNum = 20;
+ // TODO: add parallel_for
+ for (size_t i = 0; i < n; ++i)
+ {
+ L = 0; // min possible cos distance
+ R = 2; // max possible cos distance
+ lcnt = 0; // number of elements with cos distance <= L
+ rcnt = n; // number of elements with cos distance <= R
+ for (size_t ij = 0; ij < binarySearchIterNum; ++ij)
+ {
+ M = (L + R) / 2;
+ cnt = 0;
+ // Calculate the number of elements in the row with value <= M
+ for (size_t j = 0; j < n; ++j)
+ {
+ if (x[i * n + j] <= M)
+ {
+ cnt++;
+ }
+ }
+ if (cnt < filtNum)
+ {
+ L = M;
+ lcnt = cnt;
+ }
+ else
+ {
+ R = M;
+ rcnt = cnt;
+ }
+ // distance threshold is found
+ if (rcnt == filtNum)
+ {
+ break;
+ }
+ }
+ // create edges for the closest neighbors
+ for (size_t j = 0; j < n; ++j)
+ {
+ if (x[i * n + j] <= R)
+ {
+ x[i * n + j] = 1.0;
+ }
+ else
+ {
+ x[i * n + j] = 0.0;
+ }
+ }
+ // fill the diagonal of matrix with zeros
+ x[i * n + i] = 0;
+ }
+
+ // Create Laplassian matrix
+ for (size_t i = 0; i < n; ++i)
+ {
+ for (size_t j = 0; j < i; ++j)
+ {
+ algorithmFPType val = (x[i * n + j] + x[j * n + i]) / 2;
+ x[i * n + j] = -val;
+ x[j * n + i] = -val;
+ x[i * n + i] += val;
+ x[j * n + j] += val;
+ }
+ }
+
+ // std::cout << "Laplacian matrix" << std::endl;
+ // for (int i = 0; i < n; ++i) {
+ // for (int j = 0; j < n; ++j) {
+ // std::cout << x[i * n + j] << " ";
+ // }
+ // std::cout << std::endl;
+ // }
+ // std::cout << "------" << std::endl;
+
+ // Find the eigen vectors and eigne values of the matix
+ //TArray eigenvalues(n);
+ //DAAL_CHECK_MALLOC(eigenvalues.get());
+ WriteRows eigenValuesBlock(eigenvalues, 0, n);
+ DAAL_CHECK_BLOCK_STATUS(eigenValuesBlock);
+ algorithmFPType * eigenValuesPtr = eigenValuesBlock.get();
+
+ status |= computeEigenvectorsInplace(n, x, eigenValuesPtr);
+ DAAL_CHECK_STATUS_VAR(status);
+
+ // std::cout << "Eigen vectors: " << std::endl;
+ // for (int i = 0; i < n; ++i) {
+ // for (int j = 0; j < n; ++j) {
+ // std::cout << x[i * n + j] << " ";
+ // }
+ // std::cout << std::endl;
+ // }
+
+ // Fill the output matrix with eigen vectors corresponding to the smallest eigen values
+ WriteOnlyRows embedMatrix(embeddingTable, 0, n);
+ DAAL_CHECK_BLOCK_STATUS(embedMatrix);
+ algorithmFPType * embed = embedMatrix.get();
+
+ for (int i = 0; i < k; ++i)
+ {
+ for (int j = 0; j < n; ++j)
+ {
+ embed[j * k + i] = x[i * n + j];
+ }
+ }
+
+ return status;
+}
+
+} // namespace internal
+
+} // namespace spectral_embedding
+
+} // namespace algorithms
+
+} // namespace daal
diff --git a/cpp/daal/src/algorithms/spectral_embedding/spectral_embedding_kernel.h b/cpp/daal/src/algorithms/spectral_embedding/spectral_embedding_kernel.h
new file mode 100644
index 00000000000..0248ac4b288
--- /dev/null
+++ b/cpp/daal/src/algorithms/spectral_embedding/spectral_embedding_kernel.h
@@ -0,0 +1,66 @@
+/* file: spectral_embedding_kernel.h */
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/*
+//++
+// Declaration of template structs that calculate SVM Training functions.
+//--
+*/
+
+#ifndef __SPECTRAL_EMBEDDING_KERNEL_H__
+#define __SPECTRAL_EMBEDDING_KERNEL_H__
+
+#include "data_management/data/numeric_table.h"
+#include "services/daal_defines.h"
+#include "src/algorithms/kernel.h"
+
+namespace daal
+{
+namespace algorithms
+{
+namespace spectral_embedding
+{
+
+enum Method
+{
+ defaultDense = 0
+};
+
+namespace internal
+{
+
+using namespace daal::data_management;
+using namespace daal::services;
+
+struct KernelParameter : daal::algorithms::Parameter
+{
+ size_t numberOfEmbeddings = 1;
+ size_t numberOfNeighbors = 1;
+};
+
+template
+struct SpectralEmbeddingKernel : public Kernel
+{
+ services::Status compute(const NumericTable * xTable, NumericTable * embeddingTable, NumericTable * eigenTable, const KernelParameter & par);
+};
+
+} // namespace internal
+} // namespace spectral_embedding
+} // namespace algorithms
+} // namespace daal
+
+#endif
diff --git a/cpp/oneapi/dal/algo/BUILD b/cpp/oneapi/dal/algo/BUILD
index e93804d2e7e..a7981639dfc 100644
--- a/cpp/oneapi/dal/algo/BUILD
+++ b/cpp/oneapi/dal/algo/BUILD
@@ -34,6 +34,7 @@ ALGOS = [
"rbf_kernel",
"sigmoid_kernel",
"shortest_paths",
+ "spectral_embedding",
"subgraph_isomorphism",
"svm",
"triangle_counting",
diff --git a/cpp/oneapi/dal/algo/spectral_embedding.hpp b/cpp/oneapi/dal/algo/spectral_embedding.hpp
new file mode 100644
index 00000000000..3606f826809
--- /dev/null
+++ b/cpp/oneapi/dal/algo/spectral_embedding.hpp
@@ -0,0 +1,19 @@
+/*******************************************************************************
+* Copyright 2024 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#pragma once
+
+#include "oneapi/dal/algo/spectral_embedding/compute.hpp"
diff --git a/cpp/oneapi/dal/algo/spectral_embedding/BUILD b/cpp/oneapi/dal/algo/spectral_embedding/BUILD
new file mode 100644
index 00000000000..88640f87665
--- /dev/null
+++ b/cpp/oneapi/dal/algo/spectral_embedding/BUILD
@@ -0,0 +1,38 @@
+package(default_visibility = ["//visibility:public"])
+load("@onedal//dev/bazel:dal.bzl",
+ "dal_module",
+ "dal_test_suite",
+)
+
+dal_module(
+ name = "spectral_embedding",
+ auto = True,
+ dal_deps = [
+ "@onedal//cpp/oneapi/dal:core",
+ "@onedal//cpp/oneapi/dal/backend/primitives:common",
+ ],
+ extra_deps = [
+ "@onedal//cpp/daal/src/algorithms/spectral_embedding:kernel",
+ ]
+)
+
+dal_test_suite(
+ name = "interface_tests",
+ framework = "catch2",
+ hdrs = glob([
+ "test/*.hpp",
+ ]),
+ srcs = glob([
+ "test/*.cpp",
+ ]),
+ dal_deps = [
+ ":spectral_embedding",
+ ],
+)
+
+dal_test_suite(
+ name = "tests",
+ tests = [
+ ":interface_tests",
+ ],
+)
diff --git a/cpp/oneapi/dal/algo/spectral_embedding/backend/cpu/compute_kernel.cpp b/cpp/oneapi/dal/algo/spectral_embedding/backend/cpu/compute_kernel.cpp
new file mode 100644
index 00000000000..04dd7e92609
--- /dev/null
+++ b/cpp/oneapi/dal/algo/spectral_embedding/backend/cpu/compute_kernel.cpp
@@ -0,0 +1,110 @@
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "daal/src/algorithms/spectral_embedding/spectral_embedding_kernel.h"
+
+#include "oneapi/dal/algo/spectral_embedding/backend/cpu/compute_kernel.hpp"
+#include "oneapi/dal/backend/interop/common.hpp"
+#include "oneapi/dal/backend/interop/error_converter.hpp"
+#include "oneapi/dal/backend/interop/table_conversion.hpp"
+
+#include "oneapi/dal/table/row_accessor.hpp"
+#include "oneapi/dal/detail/error_messages.hpp"
+#include
+
+namespace oneapi::dal::spectral_embedding::backend {
+
+using dal::backend::context_cpu;
+using descriptor_t = detail::descriptor_base;
+
+namespace sp_emb = daal::algorithms::spectral_embedding;
+
+template
+using daal_sp_emb_kernel_t =
+ sp_emb::internal::SpectralEmbeddingKernel;
+
+using parameter_t = sp_emb::internal::KernelParameter;
+
+namespace interop = oneapi::dal::backend::interop;
+
+template
+static compute_result call_daal_kernel(const context_cpu& ctx,
+ const descriptor_t& desc,
+ const table& data) {
+ const auto daal_data = interop::convert_to_daal_table(data);
+
+ // const std::int64_t p = data.get_column_count();
+ const std::int64_t n = data.get_row_count();
+ std::int64_t k = desc.get_component_count();
+
+ // std::cout << "inside oneDAL kernel: " << n << " " << p << std::endl;
+
+ auto result = compute_result{}.set_result_options(desc.get_result_options());
+
+ if (result.get_result_options().test(result_options::embedding) ||
+ result.get_result_options().test(result_options::eigen_values)) {
+ daal::services::SharedPtr daal_input, daal_output, daal_eigen_vals;
+ array arr_output, arr_eigen_vals;
+ arr_output = array::empty(n * k);
+ arr_eigen_vals = array::empty(n);
+ daal_output = interop::convert_to_daal_homogen_table(arr_output, n, k);
+ daal_eigen_vals = interop::convert_to_daal_homogen_table(arr_eigen_vals, n, 1);
+ parameter_t daal_param;
+
+ daal_param.numberOfEmbeddings = k;
+ if (desc.get_neighbor_count() < 0) {
+ daal_param.numberOfNeighbors = n - 1;
+ }
+ else {
+ daal_param.numberOfNeighbors = desc.get_neighbor_count();
+ }
+ interop::status_to_exception(
+ interop::call_daal_kernel(ctx,
+ daal_data.get(),
+ daal_output.get(),
+ daal_eigen_vals.get(),
+ daal_param));
+ if (result.get_result_options().test(result_options::embedding)) {
+ result.set_embedding(homogen_table::wrap(arr_output, n, k));
+ }
+ if (result.get_result_options().test(result_options::eigen_values)) {
+ result.set_eigen_values(homogen_table::wrap(arr_eigen_vals, n, 1));
+ }
+ }
+
+ return result;
+}
+
+template
+static compute_result compute(const context_cpu& ctx,
+ const descriptor_t& desc,
+ const compute_input& input) {
+ return call_daal_kernel(ctx, desc, input.get_data());
+}
+
+template
+struct compute_kernel_cpu {
+ compute_result operator()(const context_cpu& ctx,
+ const descriptor_t& desc,
+ const compute_input& input) const {
+ return compute(ctx, desc, input);
+ }
+};
+
+template struct compute_kernel_cpu;
+template struct compute_kernel_cpu;
+
+} // namespace oneapi::dal::spectral_embedding::backend
diff --git a/cpp/oneapi/dal/algo/spectral_embedding/backend/cpu/compute_kernel.hpp b/cpp/oneapi/dal/algo/spectral_embedding/backend/cpu/compute_kernel.hpp
new file mode 100644
index 00000000000..855903d883c
--- /dev/null
+++ b/cpp/oneapi/dal/algo/spectral_embedding/backend/cpu/compute_kernel.hpp
@@ -0,0 +1,31 @@
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#pragma once
+
+#include "oneapi/dal/algo/spectral_embedding/compute_types.hpp"
+#include "oneapi/dal/backend/dispatcher.hpp"
+
+namespace oneapi::dal::spectral_embedding::backend {
+
+template
+struct compute_kernel_cpu {
+ compute_result operator()(const dal::backend::context_cpu& ctx,
+ const detail::descriptor_base& params,
+ const compute_input& input) const;
+};
+
+} // namespace oneapi::dal::spectral_embedding::backend
diff --git a/cpp/oneapi/dal/algo/spectral_embedding/backend/gpu/compute_kernel.hpp b/cpp/oneapi/dal/algo/spectral_embedding/backend/gpu/compute_kernel.hpp
new file mode 100644
index 00000000000..b7f4fdc63d8
--- /dev/null
+++ b/cpp/oneapi/dal/algo/spectral_embedding/backend/gpu/compute_kernel.hpp
@@ -0,0 +1,31 @@
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#pragma once
+
+#include "oneapi/dal/algo/spectral_embedding/compute_types.hpp"
+#include "oneapi/dal/backend/dispatcher.hpp"
+
+namespace oneapi::dal::spectral_embedding::backend {
+
+template
+struct compute_kernel_gpu {
+ compute_result operator()(const dal::backend::context_gpu& ctx,
+ const detail::descriptor_base& params,
+ const compute_input& input) const;
+};
+
+} // namespace oneapi::dal::spectral_embedding::backend
diff --git a/cpp/oneapi/dal/algo/spectral_embedding/backend/gpu/compute_kernel_dpc.cpp b/cpp/oneapi/dal/algo/spectral_embedding/backend/gpu/compute_kernel_dpc.cpp
new file mode 100644
index 00000000000..7cfe7dbf8b5
--- /dev/null
+++ b/cpp/oneapi/dal/algo/spectral_embedding/backend/gpu/compute_kernel_dpc.cpp
@@ -0,0 +1,36 @@
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "oneapi/dal/algo/spectral_embedding/backend/gpu/compute_kernel.hpp"
+#include "oneapi/dal/exceptions.hpp"
+
+namespace oneapi::dal::spectral_embedding::backend {
+
+template
+struct compute_kernel_gpu {
+ compute_result operator()(const dal::backend::context_gpu& ctx,
+ const detail::descriptor_base& desc,
+ const compute_input& input) const {
+ // CHANGE ERROR MESSAGE
+ throw unimplemented(
+ dal::detail::error_messages::sp_emb_dense_batch_method_is_not_implemented_for_gpu());
+ }
+};
+
+template struct compute_kernel_gpu;
+template struct compute_kernel_gpu;
+
+} // namespace oneapi::dal::spectral_embedding::backend
diff --git a/cpp/oneapi/dal/algo/spectral_embedding/common.cpp b/cpp/oneapi/dal/algo/spectral_embedding/common.cpp
new file mode 100644
index 00000000000..64f279f56c5
--- /dev/null
+++ b/cpp/oneapi/dal/algo/spectral_embedding/common.cpp
@@ -0,0 +1,94 @@
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "oneapi/dal/algo/spectral_embedding/common.hpp"
+#include "oneapi/dal/exceptions.hpp"
+
+namespace oneapi::dal::spectral_embedding::detail {
+
+result_option_id get_embedding_id() {
+ return result_option_id{ result_option_id::make_by_index(0) };
+}
+
+result_option_id get_eigen_values_id() {
+ return result_option_id{ result_option_id::make_by_index(1) };
+}
+
+template
+result_option_id get_default_result_options() {
+ return result_option_id{};
+}
+
+template <>
+result_option_id get_default_result_options() {
+ return get_embedding_id();
+}
+
+namespace v1 {
+
+template
+class descriptor_impl : public base {
+public:
+ explicit descriptor_impl() {}
+
+ std::int64_t component_count = 0;
+ std::int64_t neighbor_count = -1;
+
+ result_option_id result_options = get_default_result_options();
+};
+
+template
+descriptor_base::descriptor_base() : impl_(new descriptor_impl{}) {}
+
+template
+std::int64_t descriptor_base::get_component_count() const {
+ return impl_->component_count;
+}
+
+template
+std::int64_t descriptor_base::get_neighbor_count() const {
+ return impl_->neighbor_count;
+}
+
+template
+void descriptor_base::set_component_count_impl(std::int64_t component_count) {
+ impl_->component_count = component_count;
+}
+
+template
+void descriptor_base::set_neighbor_count_impl(std::int64_t neighbor_count) {
+ impl_->neighbor_count = neighbor_count;
+}
+
+template
+result_option_id descriptor_base::get_result_options() const {
+ return impl_->result_options;
+}
+
+template
+void descriptor_base::set_result_options_impl(const result_option_id& value) {
+ using msg = dal::detail::error_messages;
+ if (!bool(value)) {
+ throw domain_error(msg::empty_set_of_result_options());
+ }
+ impl_->result_options = value;
+}
+
+template class ONEDAL_EXPORT descriptor_base;
+
+} // namespace v1
+
+} // namespace oneapi::dal::spectral_embedding::detail
diff --git a/cpp/oneapi/dal/algo/spectral_embedding/common.hpp b/cpp/oneapi/dal/algo/spectral_embedding/common.hpp
new file mode 100644
index 00000000000..8ceb5fda692
--- /dev/null
+++ b/cpp/oneapi/dal/algo/spectral_embedding/common.hpp
@@ -0,0 +1,201 @@
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#pragma once
+
+#include "oneapi/dal/util/result_option_id.hpp"
+#include "oneapi/dal/detail/common.hpp"
+#include "oneapi/dal/detail/serialization.hpp"
+#include "oneapi/dal/table/common.hpp"
+#include "oneapi/dal/common.hpp"
+
+namespace oneapi::dal::spectral_embedding {
+
+namespace task {
+namespace v1 {
+
+/// Tag-type that parameterizes entities that are used to compute statistics.
+struct compute {};
+
+/// Alias tag-type for compute task.
+using by_default = compute;
+} // namespace v1
+
+using v1::compute;
+using v1::by_default;
+
+} // namespace task
+
+namespace method {
+namespace v1 {
+
+/// Tag-type that denotes dense_batch computational method.
+struct dense_batch {};
+
+/// Alias tag-type for the dense_batch computational method.
+using by_default = dense_batch;
+
+} // namespace v1
+
+using v1::dense_batch;
+using v1::by_default;
+
+} // namespace method
+
+/// Represents result option flag
+/// Behaves like a regular :expr`enum`.
+class result_option_id : public result_option_id_base {
+public:
+ constexpr result_option_id() = default;
+ constexpr explicit result_option_id(const result_option_id_base& base)
+ : result_option_id_base{ base } {}
+};
+
+namespace detail {
+
+ONEDAL_EXPORT result_option_id get_embedding_id();
+ONEDAL_EXPORT result_option_id get_eigen_values_id();
+
+} // namespace detail
+
+/// Result options are used to define
+/// what should algorithm return
+namespace result_options {
+
+/// Return spectral embedding
+const inline auto embedding = detail::get_embedding_id();
+
+/// Return eigen values of Laplassian matrix
+const inline auto eigen_values = detail::get_eigen_values_id();
+
+} // namespace result_options
+
+namespace detail {
+
+namespace v1 {
+
+struct descriptor_tag {};
+
+template
+class descriptor_impl;
+
+template
+constexpr bool is_valid_float_v = dal::detail::is_one_of_v;
+
+template
+constexpr bool is_valid_method_v = dal::detail::is_one_of_v;
+
+template
+constexpr bool is_valid_task_v = dal::detail::is_one_of_v;
+
+template
+class descriptor_base : public base {
+ static_assert(is_valid_task_v);
+
+public:
+ using tag_t = descriptor_tag;
+ using float_t = float;
+ using method_t = method::by_default;
+ using task_t = Task;
+
+ descriptor_base();
+
+ std::int64_t get_component_count() const;
+ std::int64_t get_neighbor_count() const;
+ result_option_id get_result_options() const;
+
+protected:
+ void set_component_count_impl(std::int64_t component_count);
+ void set_neighbor_count_impl(std::int64_t neighbor_count);
+ void set_result_options_impl(const result_option_id& value);
+
+private:
+ dal::detail::pimpl> impl_;
+};
+
+} // namespace v1
+
+using v1::descriptor_tag;
+using v1::descriptor_impl;
+using v1::descriptor_base;
+
+using v1::is_valid_float_v;
+using v1::is_valid_method_v;
+using v1::is_valid_task_v;
+
+} // namespace detail
+
+namespace v1 {
+
+/// @tparam Float The floating-point type that the algorithm uses for
+/// intermediate computations. Can be :expr:`float` or
+/// :expr:`double`.
+/// @tparam Method Tag-type that specifies an implementation of algorithm. Can
+/// be :expr:`method::dense_batch`.
+/// @tparam Task Tag-type that specifies the type of the problem to solve. Can
+/// be :expr:`task::compute`.
+
+template
+class descriptor : public detail::descriptor_base {
+ static_assert(detail::is_valid_float_v);
+ static_assert(detail::is_valid_method_v);
+ static_assert(detail::is_valid_task_v);
+ using base_t = detail::descriptor_base;
+
+public:
+ using float_t = Float;
+ using method_t = Method;
+ using task_t = Task;
+
+ /// Creates a new instance of the class with the default property values.
+ explicit descriptor() : base_t() {}
+
+ std::int64_t get_component_count() const {
+ return base_t::get_component_count();
+ }
+
+ std::int64_t get_neighbor_count() const {
+ return base_t::get_neighbor_count();
+ }
+
+ auto& set_component_count(std::int64_t component_count) {
+ base_t::set_component_count_impl(component_count);
+ return *this;
+ }
+
+ auto& set_neighbor_count(std::int64_t neighbor_count) {
+ base_t::set_neighbor_count_impl(neighbor_count);
+ return *this;
+ }
+
+ /// Choose which results should be computed and returned.
+ result_option_id get_result_options() const {
+ return base_t::get_result_options();
+ }
+
+ auto& set_result_options(const result_option_id& value) {
+ base_t::set_result_options_impl(value);
+ return *this;
+ }
+};
+
+} // namespace v1
+
+using v1::descriptor;
+
+} // namespace oneapi::dal::spectral_embedding
diff --git a/cpp/oneapi/dal/algo/spectral_embedding/compute.hpp b/cpp/oneapi/dal/algo/spectral_embedding/compute.hpp
new file mode 100644
index 00000000000..ceb49cc42fd
--- /dev/null
+++ b/cpp/oneapi/dal/algo/spectral_embedding/compute.hpp
@@ -0,0 +1,31 @@
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#pragma once
+
+#include "oneapi/dal/algo/spectral_embedding/compute_types.hpp"
+#include "oneapi/dal/algo/spectral_embedding/detail/compute_ops.hpp"
+#include "oneapi/dal/compute.hpp"
+
+namespace oneapi::dal::detail {
+namespace v1 {
+
+template
+struct compute_ops
+ : dal::spectral_embedding::detail::compute_ops {};
+
+} // namespace v1
+} // namespace oneapi::dal::detail
diff --git a/cpp/oneapi/dal/algo/spectral_embedding/compute_types.cpp b/cpp/oneapi/dal/algo/spectral_embedding/compute_types.cpp
new file mode 100644
index 00000000000..842ffd1134e
--- /dev/null
+++ b/cpp/oneapi/dal/algo/spectral_embedding/compute_types.cpp
@@ -0,0 +1,109 @@
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "oneapi/dal/algo/spectral_embedding/compute_types.hpp"
+#include "oneapi/dal/detail/common.hpp"
+#include "oneapi/dal/exceptions.hpp"
+
+namespace oneapi::dal::spectral_embedding {
+
+template
+class detail::v1::compute_input_impl : public base {
+public:
+ compute_input_impl(const table& data) : data(data) {}
+ table data;
+};
+
+template
+class detail::v1::compute_result_impl : public base {
+public:
+ table embedding;
+ table eigen_values;
+ result_option_id options;
+};
+
+using detail::v1::compute_input_impl;
+using detail::v1::compute_result_impl;
+
+namespace v1 {
+
+template
+compute_input::compute_input(const table& data) : impl_(new compute_input_impl(data)) {}
+
+template
+const table& compute_input::get_data() const {
+ return impl_->data;
+}
+
+template
+void compute_input::set_data_impl(const table& value) {
+ impl_->data = value;
+}
+
+template
+compute_result::compute_result() : impl_(new compute_result_impl{}) {}
+
+template
+const table& compute_result::get_embedding() const {
+ using msg = dal::detail::error_messages;
+ if (!get_result_options().test(result_options::embedding)) {
+ throw domain_error(msg::this_result_is_not_enabled_via_result_options());
+ }
+ return impl_->embedding;
+}
+
+template
+void compute_result::set_embedding_impl(const table& value) {
+ using msg = dal::detail::error_messages;
+ if (!get_result_options().test(result_options::embedding)) {
+ throw domain_error(msg::this_result_is_not_enabled_via_result_options());
+ }
+ impl_->embedding = value;
+}
+
+template
+const table& compute_result::get_eigen_values() const {
+ using msg = dal::detail::error_messages;
+ if (!get_result_options().test(result_options::eigen_values)) {
+ throw domain_error(msg::this_result_is_not_enabled_via_result_options());
+ }
+ return impl_->eigen_values;
+}
+
+template
+void compute_result::set_eigen_values_impl(const table& value) {
+ using msg = dal::detail::error_messages;
+ if (!get_result_options().test(result_options::eigen_values)) {
+ throw domain_error(msg::this_result_is_not_enabled_via_result_options());
+ }
+ impl_->eigen_values = value;
+}
+
+template
+const result_option_id& compute_result::get_result_options() const {
+ return impl_->options;
+}
+
+template
+void compute_result::set_result_options_impl(const result_option_id& value) {
+ impl_->options = value;
+}
+
+template class ONEDAL_EXPORT compute_input;
+template class ONEDAL_EXPORT compute_result;
+
+} // namespace v1
+} // namespace oneapi::dal::spectral_embedding
diff --git a/cpp/oneapi/dal/algo/spectral_embedding/compute_types.hpp b/cpp/oneapi/dal/algo/spectral_embedding/compute_types.hpp
new file mode 100644
index 00000000000..010b0b84760
--- /dev/null
+++ b/cpp/oneapi/dal/algo/spectral_embedding/compute_types.hpp
@@ -0,0 +1,124 @@
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#pragma once
+
+#include "oneapi/dal/algo/spectral_embedding/common.hpp"
+
+namespace oneapi::dal::spectral_embedding {
+
+namespace detail {
+namespace v1 {
+template
+class compute_input_impl;
+
+template
+class compute_result_impl;
+
+} // namespace v1
+
+using v1::compute_input_impl;
+using v1::compute_result_impl;
+
+} // namespace detail
+
+namespace v1 {
+
+/// @tparam Task Tag-type that specifies the type of the problem to solve. Can
+/// be :expr:`task::compute`.
+template
+class compute_input : public base {
+ static_assert(detail::is_valid_task_v);
+
+public:
+ using task_t = Task;
+
+ /// Creates a new instance of the class with the given :literal:`data`
+ compute_input(const table& data);
+
+ /// An $n \\times p$ table with the training data, where each row stores one
+ /// feature vector.
+ /// @remark default = table{}
+ const table& get_data() const;
+
+ auto& set_data(const table& value) {
+ set_data_impl(value);
+ return *this;
+ }
+
+protected:
+ void set_data_impl(const table& value);
+
+private:
+ dal::detail::pimpl> impl_;
+};
+
+/// @tparam Task Tag-type that specifies the type of the problem to solve. Can
+/// be :expr:`task::compute`.
+template
+class compute_result : public base {
+ static_assert(detail::is_valid_task_v);
+
+public:
+ using task_t = Task;
+
+ /// Creates a new instance of the class with the default property values.
+ compute_result();
+
+ /// The matrix of size $n \\times k$ with
+ /// spectral embeddings.
+ /// @remark default = table{}
+ const table& get_embedding() const;
+
+ auto& set_embedding(const table& value) {
+ set_embedding_impl(value);
+ return *this;
+ }
+
+ /// The matrix of size $n \\times 1$ with
+ /// eigen values of Laplassian matrix.
+ /// @remark default = table{}
+ const table& get_eigen_values() const;
+
+ auto& set_eigen_values(const table& value) {
+ set_eigen_values_impl(value);
+ return *this;
+ }
+
+ /// Result options that indicates availability of the properties
+ /// @remark default = default_result_options
+ const result_option_id& get_result_options() const;
+
+ auto& set_result_options(const result_option_id& value) {
+ set_result_options_impl(value);
+ return *this;
+ }
+
+protected:
+ void set_embedding_impl(const table&);
+ void set_eigen_values_impl(const table&);
+ void set_result_options_impl(const result_option_id&);
+
+private:
+ dal::detail::pimpl> impl_;
+};
+
+} // namespace v1
+
+using v1::compute_input;
+using v1::compute_result;
+
+} // namespace oneapi::dal::spectral_embedding
diff --git a/cpp/oneapi/dal/algo/spectral_embedding/detail/compute_ops.cpp b/cpp/oneapi/dal/algo/spectral_embedding/detail/compute_ops.cpp
new file mode 100644
index 00000000000..3cbec045d3b
--- /dev/null
+++ b/cpp/oneapi/dal/algo/spectral_embedding/detail/compute_ops.cpp
@@ -0,0 +1,42 @@
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "oneapi/dal/algo/spectral_embedding/detail/compute_ops.hpp"
+#include "oneapi/dal/algo/spectral_embedding/backend/cpu/compute_kernel.hpp"
+#include "oneapi/dal/backend/dispatcher.hpp"
+
+namespace oneapi::dal::spectral_embedding::detail {
+namespace v1 {
+
+template
+struct compute_ops_dispatcher {
+ compute_result operator()(const Policy& policy,
+ const descriptor_base& desc,
+ const compute_input& input) const {
+ using kernel_dispatcher_t = dal::backend::kernel_dispatcher< //
+ KERNEL_SINGLE_NODE_CPU(backend::compute_kernel_cpu)>;
+ return kernel_dispatcher_t()(policy, desc, input);
+ }
+};
+
+#define INSTANTIATE(F, M, T) \
+ template struct ONEDAL_EXPORT compute_ops_dispatcher;
+
+INSTANTIATE(float, method::dense_batch, task::compute)
+INSTANTIATE(double, method::dense_batch, task::compute)
+
+} // namespace v1
+} // namespace oneapi::dal::spectral_embedding::detail
diff --git a/cpp/oneapi/dal/algo/spectral_embedding/detail/compute_ops.hpp b/cpp/oneapi/dal/algo/spectral_embedding/detail/compute_ops.hpp
new file mode 100644
index 00000000000..8b988c2fb7c
--- /dev/null
+++ b/cpp/oneapi/dal/algo/spectral_embedding/detail/compute_ops.hpp
@@ -0,0 +1,78 @@
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#pragma once
+
+#include "oneapi/dal/algo/spectral_embedding/compute_types.hpp"
+#include "oneapi/dal/detail/error_messages.hpp"
+
+namespace oneapi::dal::spectral_embedding::detail {
+namespace v1 {
+
+template
+struct compute_ops_dispatcher {
+ compute_result operator()(const Context&,
+ const descriptor_base&,
+ const compute_input&) const;
+};
+
+template
+struct compute_ops {
+ using float_t = typename Descriptor::float_t;
+ using method_t = typename Descriptor::method_t;
+ using task_t = typename Descriptor::task_t;
+ using input_t = compute_input;
+ using result_t = compute_result;
+ using descriptor_base_t = descriptor_base;
+
+ void check_preconditions(const Descriptor& params, const input_t& input) const {
+ using msg = dal::detail::error_messages;
+
+ if (!input.get_data().has_data()) {
+ throw domain_error(msg::input_data_is_empty());
+ }
+ }
+
+ void check_postconditions(const Descriptor& params,
+ const input_t& input,
+ const result_t& result) const {
+ using msg = dal::detail::error_messages;
+ std::int64_t n = input.get_data().get_row_count();
+ if (result.get_result_options().test(result_options::embedding)) {
+ if (!result.get_embedding().has_data()) {
+ throw domain_error(msg::value_is_not_provided()); // TODO: update error message!
+ }
+ if (result.get_embedding().get_row_count() != n) {
+ throw domain_error(msg::incorrect_output_table_size());
+ }
+ }
+ }
+
+ template
+ auto operator()(const Context& ctx, const Descriptor& desc, const input_t& input) const {
+ check_preconditions(desc, input);
+ const auto result =
+ compute_ops_dispatcher()(ctx, desc, input);
+ check_postconditions(desc, input, result);
+ return result;
+ }
+};
+
+} // namespace v1
+
+using v1::compute_ops;
+
+} // namespace oneapi::dal::spectral_embedding::detail
diff --git a/cpp/oneapi/dal/algo/spectral_embedding/detail/compute_ops_dpc.cpp b/cpp/oneapi/dal/algo/spectral_embedding/detail/compute_ops_dpc.cpp
new file mode 100644
index 00000000000..420fa869b69
--- /dev/null
+++ b/cpp/oneapi/dal/algo/spectral_embedding/detail/compute_ops_dpc.cpp
@@ -0,0 +1,45 @@
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "oneapi/dal/algo/spectral_embedding/backend/cpu/compute_kernel.hpp"
+#include "oneapi/dal/algo/spectral_embedding/backend/gpu/compute_kernel.hpp"
+#include "oneapi/dal/algo/spectral_embedding/detail/compute_ops.hpp"
+#include "oneapi/dal/backend/dispatcher.hpp"
+
+namespace oneapi::dal::spectral_embedding::detail {
+namespace v1 {
+
+template
+struct compute_ops_dispatcher {
+ compute_result operator()(const Policy& policy,
+ const descriptor_base& params,
+ const compute_input& input) const {
+ using kernel_dispatcher_t = dal::backend::kernel_dispatcher<
+ KERNEL_SINGLE_NODE_CPU(backend::compute_kernel_cpu),
+ KERNEL_SINGLE_NODE_GPU(backend::compute_kernel_gpu)>;
+ return kernel_dispatcher_t{}(policy, params, input);
+ }
+};
+
+#define INSTANTIATE(F, M, T) \
+ template struct ONEDAL_EXPORT \
+ compute_ops_dispatcher;
+
+INSTANTIATE(float, method::dense_batch, task::compute)
+INSTANTIATE(double, method::dense_batch, task::compute)
+
+} // namespace v1
+} // namespace oneapi::dal::spectral_embedding::detail
diff --git a/cpp/oneapi/dal/algo/spectral_embedding/test/batch.cpp b/cpp/oneapi/dal/algo/spectral_embedding/test/batch.cpp
new file mode 100644
index 00000000000..82353d942e0
--- /dev/null
+++ b/cpp/oneapi/dal/algo/spectral_embedding/test/batch.cpp
@@ -0,0 +1,55 @@
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "oneapi/dal/table/homogen.hpp"
+#include "oneapi/dal/table/detail/table_builder.hpp"
+#include "oneapi/dal/table/row_accessor.hpp"
+
+#include "oneapi/dal/test/engine/common.hpp"
+#include "oneapi/dal/test/engine/fixtures.hpp"
+#include "oneapi/dal/test/engine/dataframe.hpp"
+#include "oneapi/dal/test/engine/math.hpp"
+#include "oneapi/dal/algo/spectral_embedding/test/fixture.hpp"
+
+namespace oneapi::dal::spectral_embedding::test {
+
+namespace te = dal::test::engine;
+namespace de = dal::detail;
+namespace sp_emb = oneapi::dal::spectral_embedding;
+
+template
+class spectral_embedding_batch_test
+ : public spectral_embedding_test> {
+public:
+ using base_t = spectral_embedding_test>;
+
+ void gen_dimensions() {
+ this->n_ = GENERATE(8);
+ this->p_ = GENERATE(3);
+ }
+};
+
+TEMPLATE_LIST_TEST_M(spectral_embedding_batch_test,
+ "spectral_embedding gold test",
+ "[spectral embedding][integration][cpu]",
+ spectral_embedding_types) {
+ SKIP_IF(this->not_float64_friendly());
+ SKIP_IF(this->get_policy().is_gpu());
+
+ this->test_gold_input();
+}
+
+} // namespace oneapi::dal::spectral_embedding::test
diff --git a/cpp/oneapi/dal/algo/spectral_embedding/test/fixture.hpp b/cpp/oneapi/dal/algo/spectral_embedding/test/fixture.hpp
new file mode 100644
index 00000000000..54e490c4515
--- /dev/null
+++ b/cpp/oneapi/dal/algo/spectral_embedding/test/fixture.hpp
@@ -0,0 +1,131 @@
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#pragma once
+
+#include "oneapi/dal/algo/spectral_embedding/compute.hpp"
+
+#include "oneapi/dal/test/engine/common.hpp"
+#include "oneapi/dal/test/engine/fixtures.hpp"
+#include "oneapi/dal/test/engine/dataframe.hpp"
+#include "oneapi/dal/test/engine/math.hpp"
+#include "oneapi/dal/detail/debug.hpp"
+#include
+
+namespace oneapi::dal::spectral_embedding::test {
+
+namespace te = dal::test::engine;
+namespace de = dal::detail;
+namespace sp_emb = oneapi::dal::spectral_embedding;
+
+using dal::detail::operator<<;
+
+template
+class spectral_embedding_test : public te::crtp_algo_fixture {
+public:
+ using Float = std::tuple_element_t<0, TestType>;
+ using Method = std::tuple_element_t<1, TestType>;
+ using input_t = sp_emb::compute_input<>;
+ using result_t = sp_emb::compute_result<>;
+ using descriptor_t = sp_emb::descriptor;
+
+ auto get_descriptor(std::int64_t component_count,
+ std::int64_t neighbor_count,
+ sp_emb::result_option_id compute_mode) const {
+ return descriptor_t()
+ .set_component_count(component_count)
+ .set_neighbor_count(neighbor_count)
+ .set_result_options(compute_mode);
+ }
+
+ void gen_input() {
+ std::mt19937 rnd(2007 + n_ + p_ + n_ * p_);
+ const te::dataframe data_df =
+ GENERATE_DATAFRAME(te::dataframe_builder{ n_, p_ }.fill_normal(-0.5, 0.5, 7777));
+ data_ = data_df.get_table(this->get_policy(), this->get_homogen_table_id());
+ }
+
+ void test_gold_input(Float tol = 1e-5) {
+ constexpr std::int64_t n = 8;
+ constexpr std::int64_t p = 4;
+ constexpr std::int64_t neighbor_count = 5;
+ constexpr std::int64_t component_count = 4;
+
+ constexpr Float data[n * p] = { 0.49671415, -0.1382643, 0.64768854, 1.52302986,
+ -0.23415337, -0.23413696, 1.57921282, 0.76743473,
+ -0.46947439, 0.54256004, -0.46341769, -0.46572975,
+ 0.24196227, -1.91328024, -1.72491783, -0.56228753,
+ -1.01283112, 0.31424733, -0.90802408, -1.4123037,
+ 1.46564877, -0.2257763, 0.0675282, -1.42474819,
+ -0.54438272, 0.11092259, -1.15099358, 0.37569802,
+ -0.60063869, -0.29169375, -0.60170661, 1.85227818 };
+
+ constexpr Float gth_embedding[n * component_count] = {
+ -0.353553391, 0.442842965, 0.190005876, 0.705830111, -0.353553391, 0.604392576,
+ -0.247517958, -0.595235173, -0.353553391, -0.391745507, 0.0443633719, -0.150208165,
+ -0.353553391, -0.142548722, 0.0125222995, -0.0318482841, -0.353553391, -0.499390711,
+ -0.20194266, -0.000639679859, -0.353553391, 0.00809834849, -0.683462258, 0.273398265,
+ -0.353553391, -0.0977843445, 0.449358299, 0.0195905172, -0.353553391, 0.0761353959,
+ 0.436673029, -0.220887591
+ };
+
+ constexpr Float gth_eigen_vals[n] = { 0, 3.32674524, 4.70361338, 5.26372220,
+ 5.69343808, 6.63074948, 6.80173994, 7.57999167 };
+
+ auto desc = get_descriptor(
+ component_count,
+ neighbor_count,
+ sp_emb::result_options::embedding | sp_emb::result_options::eigen_values);
+
+ table data_ = homogen_table::wrap(data, n, p);
+
+ INFO("run compute");
+ auto compute_result = this->compute(desc, data_);
+ auto embedding = compute_result.get_embedding();
+ // std::cout << "Output" << std::endl;
+ // std::cout << embedding << std::endl;
+
+ array emb_arr = row_accessor(embedding).pull({ 0, -1 });
+ for (int j = 0; j < component_count; ++j) {
+ Float diff = 0, diff_rev = 0;
+ for (int i = 0; i < n; ++i) {
+ Float val = emb_arr[i * component_count + j];
+ Float gth_val = gth_embedding[i * component_count + j];
+ diff = std::max(diff, std::abs(val - gth_val));
+ diff_rev = std::max(diff_rev, std::abs(val + gth_val));
+ }
+ REQUIRE((diff < tol || diff_rev < tol));
+ }
+
+ auto eigen_values = compute_result.get_eigen_values();
+ // std::cout << "Eigen values:" << std::endl;
+ // std::cout << eigen_values << std::endl;
+
+ array eig_val_arr = row_accessor(eigen_values).pull({ 0, -1 });
+ for (int i = 0; i < n; ++i) {
+ REQUIRE(std::abs(eig_val_arr[i] - gth_eigen_vals[i]) < tol);
+ }
+ }
+
+protected:
+ std::int64_t n_;
+ std::int64_t p_;
+ table data_;
+};
+
+using spectral_embedding_types = COMBINE_TYPES((float, double), (sp_emb::method::dense_batch));
+
+} // namespace oneapi::dal::spectral_embedding::test
diff --git a/cpp/oneapi/dal/detail/error_messages.cpp b/cpp/oneapi/dal/detail/error_messages.cpp
index 20ce68e0ef2..7fbdab3bb2a 100644
--- a/cpp/oneapi/dal/detail/error_messages.cpp
+++ b/cpp/oneapi/dal/detail/error_messages.cpp
@@ -262,6 +262,10 @@ MSG(nothing_to_compute, "Invalid combination of optional results: nothing to com
MSG(distances_are_uninitialized, "Distances are not set as an optional result")
MSG(predecessors_are_uninitialized, "Predecessors are not set as an optional result")
+/*Spectral Embedding*/
+MSG(sp_emb_dense_batch_method_is_not_implemented_for_gpu,
+ "Spectral embedding algorithm is not implemented on GPU")
+
/* SVM */
MSG(c_leq_zero, "C is lower than or equal to zero")
MSG(cache_size_lt_zero, "Cache size is lower than zero")
diff --git a/cpp/oneapi/dal/detail/error_messages.hpp b/cpp/oneapi/dal/detail/error_messages.hpp
index dbb9c8ba6ec..60e60a925f7 100644
--- a/cpp/oneapi/dal/detail/error_messages.hpp
+++ b/cpp/oneapi/dal/detail/error_messages.hpp
@@ -291,6 +291,9 @@ class ONEDAL_EXPORT error_messages {
MSG(distances_are_uninitialized);
MSG(predecessors_are_uninitialized);
+ /*Spectral Embedding*/
+ MSG(sp_emb_dense_batch_method_is_not_implemented_for_gpu);
+
/* SVM */
MSG(c_leq_zero);
MSG(cache_size_lt_zero);
diff --git a/examples/oneapi/cpp/BUILD b/examples/oneapi/cpp/BUILD
index a90cf71b7df..dc99e223346 100644
--- a/examples/oneapi/cpp/BUILD
+++ b/examples/oneapi/cpp/BUILD
@@ -76,6 +76,19 @@ dal_example_suite(
],
)
+dal_example_suite(
+ name = "spectral_clustering",
+ compile_as = [ "c++" ],
+ srcs = glob(["source/spectral_clustering/*.cpp"]),
+ dal_deps = [
+ "@onedal//cpp/oneapi/dal/algo:kmeans",
+ "@onedal//cpp/oneapi/dal/algo:kmeans_init",
+ "@onedal//cpp/oneapi/dal/algo:spectral_embedding",
+ ],
+ data = _DATA_DEPS,
+ extra_deps = _TEST_DEPS,
+)
+
dal_algo_example_suite(
algos = [
"basic_statistics",
diff --git a/examples/oneapi/cpp/source/spectral_clustering/spectral_clustering_pipeline.cpp b/examples/oneapi/cpp/source/spectral_clustering/spectral_clustering_pipeline.cpp
new file mode 100644
index 00000000000..d14a29776cc
--- /dev/null
+++ b/examples/oneapi/cpp/source/spectral_clustering/spectral_clustering_pipeline.cpp
@@ -0,0 +1,90 @@
+/*******************************************************************************
+* Copyright contributors to the oneDAL project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "example_util/utils.hpp"
+#include "oneapi/dal/algo/kmeans.hpp"
+#include "oneapi/dal/algo/kmeans_init.hpp"
+#include "oneapi/dal/algo/spectral_embedding.hpp"
+#include "oneapi/dal/io/csv.hpp"
+#include
+#include
+
+namespace dal = oneapi::dal;
+
+int main(int argc, char const *argv[]) {
+ double p = 0.01; // prunning parameter
+ std::int64_t num_spks = 8; // dimension of spectral embeddings
+
+ std::int64_t cluster_count = num_spks; // number of clusters
+ std::int64_t max_iteration_count = 300; // max iterations number for K-Means
+ std::int64_t n_init = 20; // number of K-means++ iterations
+ double accuracy_threshold = 1e-4; // threshold for early stop in K-Means
+
+ const auto voice_data_file_name =
+ get_data_path("covcormoments_dense.csv"); // Dataset with original features
+
+ std::cout << voice_data_file_name << std::endl;
+
+ const auto x_train = dal::read(dal::csv::data_source{ voice_data_file_name });
+
+ std::int64_t m = x_train.get_row_count();
+ std::int64_t n_neighbors;
+
+ if (m < 1000) {
+ n_neighbors = std::min((std::int64_t)10, m - 2);
+ }
+ else {
+ n_neighbors = (std::int64_t)(p * m);
+ }
+
+ const auto spectral_embedding_desc =
+ dal::spectral_embedding::descriptor<>()
+ .set_neighbor_count(n_neighbors)
+ .set_component_count(num_spks)
+ .set_result_options(dal::spectral_embedding::result_options::embedding |
+ dal::spectral_embedding::result_options::eigen_values);
+
+ const auto spectral_embedding_result = dal::compute(spectral_embedding_desc, x_train);
+
+ const auto spectral_embeddings =
+ spectral_embedding_result.get_embedding(); // Matrix with spectral embeddings m * num_spks
+
+ std::cout << "Spectral embeddings:\n" << spectral_embeddings << std::endl;
+
+ std::cout << "Eigen values:\n" << spectral_embedding_result.get_eigen_values() << std::endl;
+
+ const auto kmeans_init_desc =
+ dal::kmeans_init::descriptor()
+ .set_cluster_count(cluster_count)
+ .set_local_trials_count(n_init);
+
+ const auto kmeans_init_result = dal::compute(kmeans_init_desc, spectral_embeddings);
+
+ const auto initial_centroids = kmeans_init_result.get_centroids();
+
+ const auto kmeans_desc = dal::kmeans::descriptor<>()
+ .set_cluster_count(cluster_count)
+ .set_max_iteration_count(max_iteration_count)
+ .set_accuracy_threshold(accuracy_threshold);
+
+ const auto spectral_clustering_result =
+ dal::train(kmeans_desc, spectral_embeddings, initial_centroids);
+
+ std::cout << "Responses:\n" << spectral_clustering_result.get_responses() << std::endl;
+ std::cout << "Centroids:\n"
+ << spectral_clustering_result.get_model().get_centroids() << std::endl;
+ return 0;
+}
diff --git a/makefile.lst b/makefile.lst
index 92dc52ff521..ccc82a8a579 100755
--- a/makefile.lst
+++ b/makefile.lst
@@ -24,7 +24,7 @@ CORE.ALGORITHMS.CUSTOM.AVAILABLE := low_order_moments quantiles covariance cosdi
dtrees/gbt dtrees/forest linear_regression ridge_regression naivebayes stump adaboost brownboost \
logitboost svm multiclassclassifier k_nearest_neighbors logistic_regression implicit_als \
coordinate_descent jaccard triangle_counting shortest_paths subgraph_isomorphism connected_components \
- louvain tsne
+ louvain tsne spectral_embedding
classifier += classifier/inner
low_order_moments +=
@@ -68,6 +68,7 @@ implicit_als += engines distributions
engines += engines/mt19937 engines/mcg59 engines/mt2203
distributions += distributions/bernoulli distributions/normal distributions/uniform
tsne +=
+spectral_embedding += cosdistance
CORE.ALGORITHMS.FULL := \
adaboost \
@@ -141,7 +142,8 @@ CORE.ALGORITHMS.FULL := \
svd \
svm \
weak_learner/inner \
- tsne
+ tsne \
+ spectral_embedding
CORE.ALGORITHMS := $(if $(CORE.ALGORITHMS.CUSTOM), $(CORE.ALGORITHMS.CUSTOM), $(CORE.ALGORITHMS.FULL))
CORE.ALGORITHMS := $(sort $(foreach alg,$(CORE.ALGORITHMS),$(foreach alg1,$($(alg)),$(foreach alg2,$($(alg1)),$($(alg2)) $(alg2)) $(alg1)) $(alg)))
@@ -216,6 +218,7 @@ ONEAPI.ALGOS.pca := CORE.pca
ONEAPI.ALGOS.polynomial_kernel := CORE.kernel_function
ONEAPI.ALGOS.sigmoid_kernel := CORE.kernel_function
ONEAPI.ALGOS.rbf_kernel := CORE.kernel_function
+ONEAPI.ALGOS.spectral_embedding := CORE.spectral_embedding
ONEAPI.ALGOS.svm := CORE.svm
# List of algorithms in oneAPI part
@@ -244,6 +247,7 @@ ONEAPI.ALGOS := \
polynomial_kernel \
sigmoid_kernel \
rbf_kernel \
+ spectral_embedding \
svm \
jaccard \
triangle_counting \