diff --git a/sycl/include/sycl/accessor.hpp b/sycl/include/sycl/accessor.hpp index d804c9ca7af29..15be035ed1cf0 100644 --- a/sycl/include/sycl/accessor.hpp +++ b/sycl/include/sycl/accessor.hpp @@ -17,7 +17,7 @@ #include // for __SYCL_EXPORT #include // for is_genint, Try... #include // for associateWithH... -#include // for loop +#include // for loop #include // for OwnerLessBase #include // for PropWithDataKind #include // for PropertyListBase diff --git a/sycl/include/sycl/detail/assert.hpp b/sycl/include/sycl/detail/assert.hpp new file mode 100644 index 0000000000000..7fe7ac2dd5a01 --- /dev/null +++ b/sycl/include/sycl/detail/assert.hpp @@ -0,0 +1,18 @@ +//==---------- assert.hpp ---- SYCL assert support ------------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include // for assert + +#ifdef __SYCL_DEVICE_ONLY__ +// TODO remove this when 'assert' is supported in device code +#define __SYCL_ASSERT(x) +#else +#define __SYCL_ASSERT(x) assert(x) +#endif // #ifdef __SYCL_DEVICE_ONLY__ diff --git a/sycl/include/sycl/detail/async_work_group_copy_ptr.hpp b/sycl/include/sycl/detail/async_work_group_copy_ptr.hpp new file mode 100644 index 0000000000000..cbeead2cbe361 --- /dev/null +++ b/sycl/include/sycl/detail/async_work_group_copy_ptr.hpp @@ -0,0 +1,101 @@ +//==-- async_work_group_copy_ptr.hpp - OpenCL pointer conversion for AWG --==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides convertToOpenCLGroupAsyncCopyPtr, which converts a sycl::multi_ptr +// to the decorated pointer type expected by __spirv_GroupAsyncCopy. +// +// Kept narrow so group.hpp and nd_item.hpp can use it without pulling in the +// full generic_type_traits.hpp (which transitively drags in aliases.hpp, +// bit_cast.hpp and the rest of the type-trait machinery). +// +// All dependencies (DecoratedType, access::address_space, multi_ptr) are +// already required by any header that does async_work_group_copy, so this +// header adds zero transitive cost. +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include // for DecoratedType, address_space +#include // for half_impl::BIsRepresentationT +#include +#include // for fixed_width_signed/unsigned + +#include // for std::byte +#include // for uint8_t, uint16_t +#include // for remove_const_t, is_const_v + +namespace sycl { +inline namespace _V1 { +template class __SYCL_EBO vec; +namespace ext::oneapi { +class bfloat16; +} + +namespace detail { + +// Maps a SYCL element type to the OpenCL scalar type expected by +// __spirv_GroupAsyncCopy. +template struct async_copy_elem_type { + using type = T; +}; + +template +struct async_copy_elem_type>> { + using type = + std::conditional_t, fixed_width_signed, + fixed_width_unsigned>; +}; + +#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) +template <> struct async_copy_elem_type { + using type = uint8_t; +}; +#endif + +template <> struct async_copy_elem_type { + using type = half_impl::BIsRepresentationT; +}; + +template <> struct async_copy_elem_type { + // On host bfloat16 is left as-is; only rewrite to uint16_t on device, + // mirroring the behaviour of convertToOpenCLType in generic_type_traits.hpp. +#ifdef __SYCL_DEVICE_ONLY__ + using type = uint16_t; +#else + using type = ext::oneapi::bfloat16; +#endif +}; + +template struct async_copy_elem_type> { + using elem = typename async_copy_elem_type::type; +#ifdef __SYCL_DEVICE_ONLY__ + using type = std::conditional_t; +#else + using type = vec; +#endif +}; + +/// Convert a multi_ptr to the decorated raw pointer type expected by +/// __spirv_GroupAsyncCopy, rewriting the element type to its OpenCL equivalent. +template +auto convertToOpenCLGroupAsyncCopyPtr( + multi_ptr Ptr) { + using ElemNoCv = std::remove_const_t; + using OpenCLElem = typename async_copy_elem_type::type; + using ConvertedElem = std::conditional_t, + const OpenCLElem, OpenCLElem>; + using ResultType = typename DecoratedType::type *; + return reinterpret_cast(Ptr.get_decorated()); +} + +} // namespace detail +} // namespace _V1 +} // namespace sycl diff --git a/sycl/include/sycl/detail/builder.hpp b/sycl/include/sycl/detail/builder.hpp new file mode 100644 index 0000000000000..b958cac016e92 --- /dev/null +++ b/sycl/include/sycl/detail/builder.hpp @@ -0,0 +1,165 @@ +//==---------------- builder.hpp - SYCL builder helpers -------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#ifdef __SYCL_DEVICE_ONLY__ +#include +#endif + +#include +#include + +namespace sycl { +inline namespace _V1 { +template class item; +template class group; +template class h_item; +template class id; +template class nd_item; +template class range; + +namespace detail { +template T *declptr() { return static_cast(nullptr); } + +class Builder { +public: + Builder() = delete; + + template + static group + createGroup(const range &Global, const range &Local, + const range &Group, const id &Index) { + return group(Global, Local, Group, Index); + } + + template + static group createGroup(const range &Global, + const range &Local, + const id &Index) { + return group(Global, Local, Global / Local, Index); + } + + template + static ResType createSubGroupMask(BitsType Bits, size_t BitsNum) { + return ResType(Bits, BitsNum); + } + + template + static std::enable_if_t> + createItem(const range &Extent, const id &Index, + const id &Offset) { + return item(Extent, Index, Offset); + } + + template + static std::enable_if_t> + createItem(const range &Extent, const id &Index) { + return item(Extent, Index); + } + + template + static nd_item createNDItem(const item &Global, + const item &Local, + const group &Group) { + return nd_item(Global, Local, Group); + } + + template + static h_item createHItem(const item &Global, + const item &Local) { + return h_item(Global, Local); + } + + template + static h_item createHItem(const item &Global, + const item &Local, + const range &Flex) { + return h_item(Global, Local, Flex); + } + + template + static void updateItemIndex(sycl::item &Item, + const id &NextIndex) { + Item.MImpl.MIndex = NextIndex; + } + +#ifdef __SYCL_DEVICE_ONLY__ + + template + static inline constexpr bool is_valid_dimensions = (N > 0) && (N < 4); + + template static const id getElement(id *) { + static_assert(is_valid_dimensions, "invalid dimensions"); + return __spirv::initBuiltInGlobalInvocationId>(); + } + + template static const group getElement(group *) { + static_assert(is_valid_dimensions, "invalid dimensions"); + range GlobalSize{__spirv::initBuiltInGlobalSize>()}; + range LocalSize{ + __spirv::initBuiltInWorkgroupSize>()}; + range GroupRange{ + __spirv::initBuiltInNumWorkgroups>()}; + id GroupId{__spirv::initBuiltInWorkgroupId>()}; + return createGroup(GlobalSize, LocalSize, GroupRange, GroupId); + } + + template + static std::enable_if_t> getItem() { + static_assert(is_valid_dimensions, "invalid dimensions"); + id GlobalId{__spirv::initBuiltInGlobalInvocationId>()}; + range GlobalSize{__spirv::initBuiltInGlobalSize>()}; + id GlobalOffset{__spirv::initBuiltInGlobalOffset>()}; + return createItem(GlobalSize, GlobalId, GlobalOffset); + } + + template + static std::enable_if_t> getItem() { + static_assert(is_valid_dimensions, "invalid dimensions"); + id GlobalId{__spirv::initBuiltInGlobalInvocationId>()}; + range GlobalSize{__spirv::initBuiltInGlobalSize>()}; + return createItem(GlobalSize, GlobalId); + } + + template static const nd_item getElement(nd_item *) { + static_assert(is_valid_dimensions, "invalid dimensions"); + range GlobalSize{__spirv::initBuiltInGlobalSize>()}; + range LocalSize{ + __spirv::initBuiltInWorkgroupSize>()}; + range GroupRange{ + __spirv::initBuiltInNumWorkgroups>()}; + id GroupId{__spirv::initBuiltInWorkgroupId>()}; + id GlobalId{__spirv::initBuiltInGlobalInvocationId>()}; + id LocalId{__spirv::initBuiltInLocalInvocationId>()}; + id GlobalOffset{__spirv::initBuiltInGlobalOffset>()}; + group Group = + createGroup(GlobalSize, LocalSize, GroupRange, GroupId); + item GlobalItem = + createItem(GlobalSize, GlobalId, GlobalOffset); + item LocalItem = createItem(LocalSize, LocalId); + return createNDItem(GlobalItem, LocalItem, Group); + } + + template + static auto getElement(item *) + -> decltype(getItem()) { + return getItem(); + } + + template + static auto getNDItem() -> decltype(getElement(declptr>())) { + return getElement(declptr>()); + } + +#endif // __SYCL_DEVICE_ONLY__ +}; + +} // namespace detail +} // namespace _V1 +} // namespace sycl diff --git a/sycl/include/sycl/detail/builtins/builtins.hpp b/sycl/include/sycl/detail/builtins/builtins.hpp index a63882f5d9940..9b65141112bad 100644 --- a/sycl/include/sycl/detail/builtins/builtins.hpp +++ b/sycl/include/sycl/detail/builtins/builtins.hpp @@ -64,7 +64,7 @@ #pragma once #include -#include +#include #include #include #include diff --git a/sycl/include/sycl/detail/cg_types.hpp b/sycl/include/sycl/detail/cg_types.hpp index 1a505476804de..3cb3d81ca166e 100644 --- a/sycl/include/sycl/detail/cg_types.hpp +++ b/sycl/include/sycl/detail/cg_types.hpp @@ -9,20 +9,20 @@ #pragma once #include // for array -#include // for InitializedVal, NDLoop #include // for Builder #include // for HostProfilingInfo #include // for id #include // for kernel_param_kind_t +#include // for InitializedVal, NDLoop #include -#include // for group -#include // for h_item -#include // for id -#include // for item -#include // for kernel_handler -#include // for nd_item -#include // for nd_range -#include // for range, operator* +#include // for group +#include // for h_item +#include // for id +#include // for item +#include // for kernel_handler +#include // for nd_item +#include // for nd_range +#include // for range, operator* #include // for function #include // for size_t diff --git a/sycl/include/sycl/detail/common.hpp b/sycl/include/sycl/detail/common.hpp index 27e2b0560e81e..100d64309cc00 100644 --- a/sycl/include/sycl/detail/common.hpp +++ b/sycl/include/sycl/detail/common.hpp @@ -8,11 +8,11 @@ #pragma once -#include // for __SYCL_ALWAYS_INLINE +#include #include // for __SYCL_EXPORT +#include #include // for array -#include // for assert #include // for size_t #include #include // for enable_if_t @@ -166,113 +166,9 @@ class __SYCL_EXPORT tls_code_loc_t { } // namespace _V1 } // namespace sycl -#ifdef __SYCL_DEVICE_ONLY__ -// TODO remove this when 'assert' is supported in device code -#define __SYCL_ASSERT(x) -#else -#define __SYCL_ASSERT(x) assert(x) -#endif // #ifdef __SYCL_DEVICE_ONLY__ - namespace sycl { inline namespace _V1 { namespace detail { -// Produces N-dimensional object of type T whose all components are initialized -// to given integer value. -template class T> struct InitializedVal { - template static T get(); -}; - -// Specialization for a one-dimensional type. -template