Skip to content

Commit ed0f9cd

Browse files
committed
Update header for SYCL complex implementation and fix issue with
multi_ptr cast Newest compiler nightly change include path for header, this change requires to update to properly check and include sycl complex headers. multi_ptr cast as it is caused ambigous call to static_cast, this patch work-around the issue overloading load function using raw_pointers. Signed-off-by: nscipione <[email protected]>
1 parent e44e2b6 commit ed0f9cd

File tree

3 files changed

+39
-9
lines changed

3 files changed

+39
-9
lines changed

onemath/sycl/blas/include/blas_meta.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@
3030
#ifdef BLAS_ENABLE_COMPLEX
3131
#define SYCL_EXT_ONEAPI_COMPLEX
3232
#include <complex>
33-
#if __has_include(<ext/oneapi/experimental/complex/complex.hpp>)
34-
#include <ext/oneapi/experimental/complex/complex.hpp>
33+
#if __has_include(<sycl/ext/oneapi/experimental/complex/complex.hpp>)
34+
#include <sycl/ext/oneapi/experimental/complex/complex.hpp>
3535
#else
36-
#include <ext/oneapi/experimental/sycl_complex.hpp>
36+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3737
#endif
3838
#endif
3939

onemath/sycl/blas/src/operations/blas3/gemm_load_store_complex.hpp

+14
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,26 @@ class vec_complex {
116116
m_Data = *(Ptr + Offset * NumElements);
117117
}
118118

119+
// Load
120+
template <address_t Space, decorated_t DecorateAddress>
121+
void load(size_t Offset,
122+
const DataT* Ptr) {
123+
m_Data = *(Ptr + Offset * NumElements);
124+
}
125+
119126
// Store
120127
template <address_t Space, decorated_t DecorateAddress>
121128
void store(size_t Offset,
122129
sycl::multi_ptr<DataT, Space, DecorateAddress> Ptr) const {
123130
*(Ptr + Offset * NumElements) = m_Data;
124131
}
132+
133+
// Store
134+
template <address_t Space, decorated_t DecorateAddress>
135+
void store(size_t Offset,
136+
DataT* Ptr) const {
137+
*(Ptr + Offset * NumElements) = m_Data;
138+
}
125139
};
126140

127141
/*! @brief Partial specialization of the Packetize class dedicated to

onemath/sycl/blas/src/operations/blas3/gemm_local.hpp

+22-6
Original file line numberDiff line numberDiff line change
@@ -527,12 +527,28 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
527527
element_t *reg, OutputPointerType out_ptr) {
528528
vector_out_t out_vec{};
529529

530-
out_vec.template load<address_t::private_space>(
531-
0, sycl::multi_ptr<const element_t, address_t::private_space>(reg));
532-
out_vec *= alpha_;
533-
534-
out_vec.template store<address_t::global_space>(
535-
0, sycl::multi_ptr<element_t, address_t::global_space>(out_ptr));
530+
if constexpr (std::is_same_v<
531+
element_t,
532+
sycl::ext::oneapi::experimental::complex<float>> ||
533+
std::is_same_v<
534+
element_t,
535+
sycl::ext::oneapi::experimental::complex<double>>) {
536+
out_vec.template load<address_t::private_space,
537+
sycl::access::decorated::legacy>(0, reg);
538+
out_vec *= alpha_;
539+
540+
out_vec.template store<address_t::global_space, sycl::access::decorated::legacy>(
541+
0, out_ptr);
542+
} else {
543+
out_vec.template load<address_t::private_space,
544+
sycl::access::decorated::legacy>(
545+
0, sycl::multi_ptr<const element_t, address_t::private_space>(reg));
546+
out_vec *= alpha_;
547+
548+
out_vec.template store<address_t::global_space,
549+
sycl::access::decorated::legacy>(
550+
0, sycl::multi_ptr<element_t, address_t::global_space>(out_ptr));
551+
}
536552
}
537553
/*!
538554
* @brief Store the computed gemm result to the C matrix

0 commit comments

Comments
 (0)