diff --git a/blas/tpls/KokkosBlas2_gemv_tpl_spec_decl.hpp b/blas/tpls/KokkosBlas2_gemv_tpl_spec_decl.hpp index 2ace065808..2f7444252d 100644 --- a/blas/tpls/KokkosBlas2_gemv_tpl_spec_decl.hpp +++ b/blas/tpls/KokkosBlas2_gemv_tpl_spec_decl.hpp @@ -777,17 +777,6 @@ KOKKOSBLAS2_CGEMV_ROCBLAS(Kokkos::LayoutRight, Kokkos::HIPSpace, false) namespace KokkosBlas { namespace Impl { -inline oneapi::mkl::transpose mode_kk_to_onemkl(char mode_kk) { - switch (toupper(mode_kk)) { - case 'N': return oneapi::mkl::transpose::nontrans; - case 'T': return oneapi::mkl::transpose::trans; - case 'C': return oneapi::mkl::transpose::conjtrans; - default:; - } - throw std::invalid_argument( - "Invalid mode for oneMKL (should be one of N, T, C)"); -} - template struct kokkos_to_std_type_map { using type = T; @@ -829,7 +818,7 @@ struct kokkos_to_std_type_map { bool row_major = std::is_same::value; \ const std::int64_t M = A.extent(0); \ const std::int64_t N = A.extent(1); \ - oneapi::mkl::transpose trans = mode_kk_to_onemkl(kk_trans[0]); \ + oneapi::mkl::transpose trans = trans_mode_kk_to_onemkl(kk_trans[0]); \ const std::int64_t LDA = row_major ? A.stride(0) : A.stride(1); \ std::string label = "KokkosBlas::gemv[TPL_ONEMKL," + \ Kokkos::ArithTraits::name() + "]"; \ diff --git a/blas/tpls/KokkosBlas3_gemm_tpl_spec_avail.hpp b/blas/tpls/KokkosBlas3_gemm_tpl_spec_avail.hpp index 8e96898b10..fc866d9b70 100644 --- a/blas/tpls/KokkosBlas3_gemm_tpl_spec_avail.hpp +++ b/blas/tpls/KokkosBlas3_gemm_tpl_spec_avail.hpp @@ -182,6 +182,46 @@ KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_ROCBLAS(Kokkos::complex, Kokkos::LayoutRight, Kokkos::HIPSpace) #endif + +#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) && defined(KOKKOS_ENABLE_SYCL) + +#define KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(SCALAR, LAYOUT, MEMSPACE) \ + template <> \ + struct gemm_tpl_spec_avail< \ + Kokkos::Experimental::SYCL, \ + Kokkos::View, \ + Kokkos::MemoryTraits >, \ + Kokkos::View, \ + Kokkos::MemoryTraits >, \ + Kokkos::View, \ + Kokkos::MemoryTraits > > { \ + enum : bool { value = true }; \ + }; + +KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(double, Kokkos::LayoutLeft, + Kokkos::Experimental::SYCLDeviceUSMSpace) +KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(float, Kokkos::LayoutLeft, + Kokkos::Experimental::SYCLDeviceUSMSpace) +KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(Kokkos::complex, Kokkos::LayoutLeft, + Kokkos::Experimental::SYCLDeviceUSMSpace) +KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(Kokkos::complex, Kokkos::LayoutLeft, + Kokkos::Experimental::SYCLDeviceUSMSpace) + +KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(double, Kokkos::LayoutRight, + Kokkos::Experimental::SYCLDeviceUSMSpace) +KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(float, Kokkos::LayoutRight, + Kokkos::Experimental::SYCLDeviceUSMSpace) +KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(Kokkos::complex, + Kokkos::LayoutRight, + Kokkos::Experimental::SYCLDeviceUSMSpace) +KOKKOSBLAS3_GEMM_TPL_SPEC_AVAIL_MKL(Kokkos::complex, Kokkos::LayoutRight, + Kokkos::Experimental::SYCLDeviceUSMSpace) + +#endif + } // namespace Impl } // namespace KokkosBlas diff --git a/blas/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp b/blas/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp index 66177e28a6..d6d11897a3 100644 --- a/blas/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp +++ b/blas/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp @@ -501,4 +501,146 @@ KOKKOSBLAS3_CGEMM_ROCBLAS(Kokkos::LayoutRight, Kokkos::HIPSpace, false) } // namespace KokkosBlas #endif // KOKKOSKERNELS_ENABLE_TPL_ROCBLAS +#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) && defined(KOKKOS_ENABLE_SYCL) +#include +#include + +namespace KokkosBlas::Impl { + +/*! +SCALAR_TYPE is the Kokkos Kernels type +TPL_SCALAR_TYPE is the type MKL accents for SCALAR_TYPE +*/ +#define KOKKOSBLAS3_XGEMM_MKL(SCALAR_TYPE, TPL_SCALAR_TYPE, LAYOUT, MEM_SPACE, \ + ETI_SPEC_AVAIL) \ + template <> \ + struct GEMM< \ + Kokkos::Experimental::SYCL, \ + Kokkos::View, \ + Kokkos::MemoryTraits >, \ + Kokkos::View, \ + Kokkos::MemoryTraits >, \ + Kokkos::View, \ + Kokkos::MemoryTraits >, \ + true, ETI_SPEC_AVAIL> { \ + typedef SCALAR_TYPE SCALAR; \ + typedef Kokkos::View< \ + const SCALAR**, LAYOUT, \ + Kokkos::Device, \ + Kokkos::MemoryTraits > \ + AViewType; \ + typedef Kokkos::View< \ + const SCALAR**, LAYOUT, \ + Kokkos::Device, \ + Kokkos::MemoryTraits > \ + BViewType; \ + typedef Kokkos::View< \ + SCALAR**, LAYOUT, \ + Kokkos::Device, \ + Kokkos::MemoryTraits > \ + CViewType; \ + \ + static void gemm(const typename CViewType::execution_space& space, \ + const char transA[], const char transB[], \ + typename AViewType::const_value_type& alpha, \ + const AViewType& A, const BViewType& B, \ + typename CViewType::const_value_type& beta, \ + const CViewType& C) { \ + Kokkos::Profiling::pushRegion("KokkosBlas::gemm[TPL_MKL," #SCALAR_TYPE \ + "]"); \ + \ + const bool A_t = (transA[0] != 'N') && (transA[0] != 'n'); \ + const int64_t M = static_cast(C.extent(0)); \ + const int64_t N = static_cast(C.extent(1)); \ + const int64_t K = static_cast(A.extent(A_t ? 0 : 1)); \ + \ + constexpr bool is_lr = std::is_same::value; \ + \ + const int64_t ast = is_lr ? A.stride(0) : A.stride(1); \ + const int64_t lda = ast == 0 ? 1 : ast; \ + const int64_t bst = is_lr ? B.stride(0) : B.stride(1); \ + const int64_t ldb = bst == 0 ? 1 : bst; \ + const int64_t cst = is_lr ? C.stride(0) : C.stride(1); \ + const int64_t ldc = cst == 0 ? 1 : cst; \ + \ + oneapi::mkl::transpose transa = trans_mode_kk_to_onemkl(transA[0]); \ + oneapi::mkl::transpose transb = trans_mode_kk_to_onemkl(transB[0]); \ + oneapi::mkl::blas::compute_mode mode = \ + oneapi::mkl::blas::compute_mode::standard; \ + \ + if constexpr (!is_lr) { \ + oneapi::mkl::blas::column_major::gemm( \ + space.sycl_queue(), transa, transb, M, N, K, alpha, \ + reinterpret_cast(A.data()), lda, \ + reinterpret_cast(B.data()), ldb, beta, \ + reinterpret_cast(C.data()), ldc, mode); \ + } else { \ + oneapi::mkl::blas::row_major::gemm( \ + space.sycl_queue(), transa, transb, M, N, K, alpha, \ + reinterpret_cast(A.data()), lda, \ + reinterpret_cast(B.data()), ldb, beta, \ + reinterpret_cast(C.data()), ldc, mode); \ + } \ + \ + Kokkos::Profiling::popRegion(); \ + } \ + }; + +#define KOKKOSBLAS3_DGEMM_MKL(LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \ + KOKKOSBLAS3_XGEMM_MKL(double, double, LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) + +#define KOKKOSBLAS3_SGEMM_MKL(LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \ + KOKKOSBLAS3_XGEMM_MKL(float, float, LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) + +#define KOKKOSBLAS3_ZGEMM_MKL(LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \ + KOKKOSBLAS3_XGEMM_MKL(Kokkos::complex, std::complex, LAYOUT, \ + MEM_SPACE, ETI_SPEC_AVAIL) + +#define KOKKOSBLAS3_CGEMM_MKL(LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \ + KOKKOSBLAS3_XGEMM_MKL(Kokkos::complex, std::complex, LAYOUT, \ + MEM_SPACE, ETI_SPEC_AVAIL) + +// ETI_SPEC_AVAIL is both false and true here, because we want to use +// MKL regardless of whether ETI is available. +KOKKOSBLAS3_DGEMM_MKL(Kokkos::LayoutLeft, + Kokkos::Experimental::SYCLDeviceUSMSpace, true) +KOKKOSBLAS3_DGEMM_MKL(Kokkos::LayoutLeft, + Kokkos::Experimental::SYCLDeviceUSMSpace, false) +KOKKOSBLAS3_DGEMM_MKL(Kokkos::LayoutRight, + Kokkos::Experimental::SYCLDeviceUSMSpace, true) +KOKKOSBLAS3_DGEMM_MKL(Kokkos::LayoutRight, + Kokkos::Experimental::SYCLDeviceUSMSpace, false) + +KOKKOSBLAS3_SGEMM_MKL(Kokkos::LayoutLeft, + Kokkos::Experimental::SYCLDeviceUSMSpace, true) +KOKKOSBLAS3_SGEMM_MKL(Kokkos::LayoutLeft, + Kokkos::Experimental::SYCLDeviceUSMSpace, false) +KOKKOSBLAS3_SGEMM_MKL(Kokkos::LayoutRight, + Kokkos::Experimental::SYCLDeviceUSMSpace, true) +KOKKOSBLAS3_SGEMM_MKL(Kokkos::LayoutRight, + Kokkos::Experimental::SYCLDeviceUSMSpace, false) + +KOKKOSBLAS3_ZGEMM_MKL(Kokkos::LayoutLeft, + Kokkos::Experimental::SYCLDeviceUSMSpace, true) +KOKKOSBLAS3_ZGEMM_MKL(Kokkos::LayoutLeft, + Kokkos::Experimental::SYCLDeviceUSMSpace, false) +KOKKOSBLAS3_ZGEMM_MKL(Kokkos::LayoutRight, + Kokkos::Experimental::SYCLDeviceUSMSpace, true) +KOKKOSBLAS3_ZGEMM_MKL(Kokkos::LayoutRight, + Kokkos::Experimental::SYCLDeviceUSMSpace, false) + +KOKKOSBLAS3_CGEMM_MKL(Kokkos::LayoutLeft, + Kokkos::Experimental::SYCLDeviceUSMSpace, true) +KOKKOSBLAS3_CGEMM_MKL(Kokkos::LayoutLeft, + Kokkos::Experimental::SYCLDeviceUSMSpace, false) +KOKKOSBLAS3_CGEMM_MKL(Kokkos::LayoutRight, + Kokkos::Experimental::SYCLDeviceUSMSpace, true) +KOKKOSBLAS3_CGEMM_MKL(Kokkos::LayoutRight, + Kokkos::Experimental::SYCLDeviceUSMSpace, false) +} // namespace KokkosBlas::Impl +#endif // KOKKOSKERNELS_ENABLE_TPL_MKL && KOKKOS_ENABLE_SYCL + #endif diff --git a/blas/tpls/KokkosBlas_tpl_spec.hpp b/blas/tpls/KokkosBlas_tpl_spec.hpp index a1eee4b69c..db1f1603d9 100644 --- a/blas/tpls/KokkosBlas_tpl_spec.hpp +++ b/blas/tpls/KokkosBlas_tpl_spec.hpp @@ -231,4 +231,30 @@ struct MagmaSingleton { } // namespace KokkosBlas #endif // KOKKOSKERNELS_ENABLE_TPL_MAGMA +#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) && defined(KOKKOS_ENABLE_SYCL) +#include +#include + +namespace KokkosBlas { +namespace Impl { + +/// \brief This function converts KK transpose mode to MKL transpose mode +inline oneapi::mkl::transpose trans_mode_kk_to_onemkl(char mode_kk) { + switch (toupper(mode_kk)) { + case 'N': return oneapi::mkl::transpose::nontrans; + case 'T': return oneapi::mkl::transpose::trans; + case 'C': return oneapi::mkl::transpose::conjtrans; + default:; + } + std::stringstream ss; + ss << "Invalid mode \"" << mode_kk + << "\" for oneMKL (should be one of N, T, C)"; + throw std::invalid_argument(ss.str()); +} + +} // namespace Impl +} // namespace KokkosBlas + +#endif // KOKKOSKERNELS_ENABLE_TPL_MKL && KOKKOS_ENABLE_SYCL + #endif // KOKKOSBLAS_TPL_SPEC_HPP_