Skip to content

Commit 6411407

Browse files
authored
Merge pull request #4 from s-Nick/fix_header_and_complex_issue
[BLAS] Fix complex header inclusion and multi_ptr cast
2 parents e44e2b6 + f367f92 commit 6411407

File tree

3 files changed

+40
-9
lines changed

3 files changed

+40
-9
lines changed

onemath/sycl/blas/include/blas_meta.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@
3030
#ifdef BLAS_ENABLE_COMPLEX
3131
#define SYCL_EXT_ONEAPI_COMPLEX
3232
#include <complex>
33-
#if __has_include(<ext/oneapi/experimental/complex/complex.hpp>)
34-
#include <ext/oneapi/experimental/complex/complex.hpp>
33+
#if __has_include(<sycl/ext/oneapi/experimental/complex/complex.hpp>)
34+
#include <sycl/ext/oneapi/experimental/complex/complex.hpp>
3535
#else
36-
#include <ext/oneapi/experimental/sycl_complex.hpp>
36+
#include <sycl/ext/oneapi/experimental/sycl_complex.hpp>
3737
#endif
3838
#endif
3939

onemath/sycl/blas/src/operations/blas3/gemm_load_store_complex.hpp

+12
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,24 @@ class vec_complex {
116116
m_Data = *(Ptr + Offset * NumElements);
117117
}
118118

119+
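// Load from a raw pointer (Space and DecorateAddress do not appear in the parameters, so callers must specify them explicitly)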
120+
template <address_t Space, decorated_t DecorateAddress>
121+
void load(size_t Offset, const DataT *Ptr) {
122+
m_Data = *(Ptr + Offset * NumElements);
123+
}
124+
119125
// Store
120126
template <address_t Space, decorated_t DecorateAddress>
121127
void store(size_t Offset,
122128
sycl::multi_ptr<DataT, Space, DecorateAddress> Ptr) const {
123129
*(Ptr + Offset * NumElements) = m_Data;
124130
}
131+
132+
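// Store through a raw pointer (Space and DecorateAddress do not appear in the parameters, so callers must specify them explicitly)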
133+
template <address_t Space, decorated_t DecorateAddress>
134+
void store(size_t Offset, DataT *Ptr) const {
135+
*(Ptr + Offset * NumElements) = m_Data;
136+
}
125137
};
126138

127139
/*! @brief Partial specialization of the Packetize class dedicated to

onemath/sycl/blas/src/operations/blas3/gemm_local.hpp

+25-6
Original file line numberDiff line numberDiff line change
@@ -527,12 +527,31 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
527527
element_t *reg, OutputPointerType out_ptr) {
528528
vector_out_t out_vec{};
529529

530-
out_vec.template load<address_t::private_space>(
531-
0, sycl::multi_ptr<const element_t, address_t::private_space>(reg));
532-
out_vec *= alpha_;
533-
534-
out_vec.template store<address_t::global_space>(
535-
0, sycl::multi_ptr<element_t, address_t::global_space>(out_ptr));
530+
// This if-statement is necessary starting from the late-2024 nightly builds, because
531+
// an update made casting raw pointers of sycl::complex to multi_ptr
532+
// ambiguous.
533+
if constexpr (std::is_same_v<
534+
element_t,
535+
sycl::ext::oneapi::experimental::complex<float>> ||
536+
std::is_same_v<
537+
element_t,
538+
sycl::ext::oneapi::experimental::complex<double>>) {
539+
out_vec.template load<address_t::private_space,
540+
sycl::access::decorated::legacy>(0, reg);
541+
out_vec *= alpha_;
542+
543+
out_vec.template store<address_t::global_space,
544+
sycl::access::decorated::legacy>(0, out_ptr);
545+
} else {
546+
out_vec.template load<address_t::private_space,
547+
sycl::access::decorated::legacy>(
548+
0, sycl::multi_ptr<const element_t, address_t::private_space>(reg));
549+
out_vec *= alpha_;
550+
551+
out_vec.template store<address_t::global_space,
552+
sycl::access::decorated::legacy>(
553+
0, sycl::multi_ptr<element_t, address_t::global_space>(out_ptr));
554+
}
536555
}
537556
/*!
538557
* @brief Store the computed gemm result to the C matrix

0 commit comments

Comments
 (0)