Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sycl/include/sycl/accessor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include <sycl/detail/export.hpp> // for __SYCL_EXPORT
#include <sycl/detail/generic_type_traits.hpp> // for is_genint, Try...
#include <sycl/detail/handler_proxy.hpp> // for associateWithH...
#include <sycl/detail/helpers.hpp> // for loop
#include <sycl/detail/loop.hpp> // for loop
#include <sycl/detail/owner_less_base.hpp> // for OwnerLessBase
#include <sycl/detail/property_helper.hpp> // for PropWithDataKind
#include <sycl/detail/property_list_base.hpp> // for PropertyListBase
Expand Down
18 changes: 18 additions & 0 deletions sycl/include/sycl/detail/assert.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
//==---------- assert.hpp ---- SYCL assert support ------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#pragma once

#include <cassert> // for assert

// __SYCL_ASSERT(x): assertion macro usable from code that is compiled for both
// host and device.
//  - Host: forwards to the standard assert(x).
//  - Device (__SYCL_DEVICE_ONLY__ defined): a no-op; x is not evaluated.
// The device branch expands to the void expression ((void)0) -- the same form
// assert takes when NDEBUG is defined -- so the macro is a well-formed
// expression on both paths (e.g. inside a comma expression) instead of
// expanding to nothing on device only.
#ifdef __SYCL_DEVICE_ONLY__
// TODO remove this when 'assert' is supported in device code
#define __SYCL_ASSERT(x) ((void)0)
#else
#define __SYCL_ASSERT(x) assert(x)
#endif // #ifdef __SYCL_DEVICE_ONLY__
101 changes: 101 additions & 0 deletions sycl/include/sycl/detail/async_work_group_copy_ptr.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
//==-- async_work_group_copy_ptr.hpp - OpenCL pointer conversion for AWG --==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Provides convertToOpenCLGroupAsyncCopyPtr, which converts a sycl::multi_ptr
// to the decorated pointer type expected by __spirv_GroupAsyncCopy.
//
// Kept narrow so group.hpp and nd_item.hpp can use it without pulling in the
// full generic_type_traits.hpp (which transitively drags in aliases.hpp,
// bit_cast.hpp and the rest of the type-trait machinery).
//
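// The core dependencies (DecoratedType, access::address_space, multi_ptr) are
// already required by any header that does async_work_group_copy; beyond those
// this header only includes the half forward declarations and
// type_traits/integer_traits.hpp, so its transitive cost stays small.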
//
//===----------------------------------------------------------------------===//

#pragma once

#include <sycl/access/access.hpp> // for DecoratedType, address_space
#include <sycl/detail/fwd/half.hpp> // for half_impl::BIsRepresentationT
#include <sycl/detail/fwd/multi_ptr.hpp>
#include <sycl/detail/type_traits/integer_traits.hpp> // for fixed_width_signed/unsigned

#include <cstddef> // for std::byte
#include <stdint.h> // for uint8_t, uint16_t
#include <type_traits> // for remove_const_t, is_const_v

namespace sycl {
inline namespace _V1 {
// Forward declarations so the specializations in this header can name vec and
// ext::oneapi::bfloat16 without including the headers that define them.
// NOTE(review): __SYCL_EBO is a macro that is not defined in this file;
// presumably it is provided by one of the includes above -- confirm.
template <typename DataT, int NumElements> class __SYCL_EBO vec;
namespace ext::oneapi {
class bfloat16;
}

namespace detail {

// Trait that yields, as ::type, the element type to use when handing a pointer
// to __spirv_GroupAsyncCopy.
//
// This is the primary template: any type without a more specific
// specialization maps to itself. The second (defaulted) parameter exists only
// so that partial specializations can be constrained with std::enable_if_t.
template <typename T, typename Enable = void>
struct async_copy_elem_type {
  using type = T;
};

template <typename T>
struct async_copy_elem_type<T, std::enable_if_t<std::is_integral_v<T>>> {
using type =
std::conditional_t<std::is_signed_v<T>, fixed_width_signed<sizeof(T)>,
fixed_width_unsigned<sizeof(T)>>;
};

// std::byte is copied as uint8_t. The specialization is compiled unless
// _HAS_STD_BYTE is defined to 0.
// NOTE(review): _HAS_STD_BYTE looks like the MSVC STL switch that removes
// std::byte -- confirm.
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
template <> struct async_copy_elem_type<std::byte> {
using type = uint8_t;
};
#endif

// half maps to half_impl::BIsRepresentationT.
// NOTE(review): presumably the underlying representation type declared in
// fwd/half.hpp -- confirm there.
template <> struct async_copy_elem_type<half> {
using type = half_impl::BIsRepresentationT;
};

// bfloat16: the mapped type depends on whether this is the device compilation
// pass (__SYCL_DEVICE_ONLY__ defined) or the host pass.
template <> struct async_copy_elem_type<ext::oneapi::bfloat16> {
// On host bfloat16 is left as-is; only rewrite to uint16_t on device,
// mirroring the behaviour of convertToOpenCLType in generic_type_traits.hpp.
#ifdef __SYCL_DEVICE_ONLY__
using type = uint16_t;
#else
using type = ext::oneapi::bfloat16;
#endif
};

// vec<T, N>: the element type is mapped first, then a vector type is rebuilt
// around the mapped element.
template <typename T, int N> struct async_copy_elem_type<vec<T, N>> {
// Mapped type of a single element of the vec.
using elem = typename async_copy_elem_type<T>::type;
#ifdef __SYCL_DEVICE_ONLY__
// Device: a one-element vec maps to the plain scalar elem; any other N maps to
// an N-element vector declared with the ext_vector_type attribute.
using type = std::conditional_t<N == 1, elem,
elem __attribute__((ext_vector_type(N)))>;
#else
// Host: remains a sycl vec, but over the mapped element type.
using type = vec<elem, N>;
#endif
};

/// Produce the decorated raw pointer that __spirv_GroupAsyncCopy expects from
/// a multi_ptr.
///
/// The pointee type is ElementType with top-level const stripped, mapped
/// through async_copy_elem_type, and with const re-applied if ElementType was
/// const. The result is a pointer to DecoratedType<that type, Space>::type,
/// obtained by reinterpret_cast of Ptr.get_decorated().
///
/// \param Ptr the multi_ptr to convert.
/// \return the reinterpreted decorated pointer.
template <typename ElementType, access::address_space Space,
          access::decorated DecorateAddress>
auto convertToOpenCLGroupAsyncCopyPtr(
    multi_ptr<ElementType, Space, DecorateAddress> Ptr) {
  using MappedT =
      typename async_copy_elem_type<std::remove_const_t<ElementType>>::type;
  if constexpr (std::is_const_v<ElementType>) {
    // Constness of the source element carries over to the mapped element.
    using TargetPtrT = typename DecoratedType<const MappedT, Space>::type *;
    return reinterpret_cast<TargetPtrT>(Ptr.get_decorated());
  } else {
    using TargetPtrT = typename DecoratedType<MappedT, Space>::type *;
    return reinterpret_cast<TargetPtrT>(Ptr.get_decorated());
  }
}

} // namespace detail
} // namespace _V1
} // namespace sycl
165 changes: 165 additions & 0 deletions sycl/include/sycl/detail/builder.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
//==---------------- builder.hpp - SYCL builder helpers -------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#pragma once

#ifdef __SYCL_DEVICE_ONLY__
#include <sycl/__spirv/spirv_vars.hpp>
#endif

#include <cstddef>
#include <type_traits>

namespace sycl {
inline namespace _V1 {
// Forward declarations only: Builder's member templates name these index-space
// classes but this header does not include their definitions.
template <int Dims, bool WithOffset> class item;
template <int Dims> class group;
template <int Dims> class h_item;
template <int Dims> class id;
template <int Dims> class nd_item;
template <int Dims> class range;

namespace detail {
/// Returns a null pointer of type T*. In this file it is used only as a type
/// tag to select a Builder::getElement overload; the overloads take the
/// pointer as an unnamed parameter and never dereference it.
template <typename T> T *declptr() { return static_cast<T *>(nullptr); }

/// Static-only collection of factory helpers that construct the index-space
/// objects (group, item, nd_item, h_item) and sub-group masks by forwarding to
/// their constructors.
/// NOTE(review): presumably those classes befriend Builder so it can reach
/// non-public constructors and members (see updateItemIndex) -- confirm.
class Builder {
public:
  // Not instantiable: every member is static.
  Builder() = delete;

  /// Builds a group from explicit global, local and group ranges plus the
  /// group index.
  template <int Dims>
  static group<Dims>
  createGroup(const range<Dims> &Global, const range<Dims> &Local,
              const range<Dims> &Group, const id<Dims> &Index) {
    return group<Dims>(Global, Local, Group, Index);
  }

  /// Builds a group where the group range is computed as Global / Local.
  template <int Dims>
  static group<Dims> createGroup(const range<Dims> &Global,
                                 const range<Dims> &Local,
                                 const id<Dims> &Index) {
    return group<Dims>(Global, Local, Global / Local, Index);
  }

  /// Constructs ResType from (Bits, BitsNum).
  /// std::size_t is spelled with its namespace because this header includes
  /// <cstddef>, which only guarantees the std:: name.
  template <class ResType, typename BitsType>
  static ResType createSubGroupMask(BitsType Bits, std::size_t BitsNum) {
    return ResType(Bits, BitsNum);
  }

  /// Builds an item that carries an offset (enabled only if WithOffset).
  template <int Dims, bool WithOffset>
  static std::enable_if_t<WithOffset, item<Dims, WithOffset>>
  createItem(const range<Dims> &Extent, const id<Dims> &Index,
             const id<Dims> &Offset) {
    return item<Dims, WithOffset>(Extent, Index, Offset);
  }

  /// Builds an item without an offset (enabled only if !WithOffset).
  template <int Dims, bool WithOffset>
  static std::enable_if_t<!WithOffset, item<Dims, WithOffset>>
  createItem(const range<Dims> &Extent, const id<Dims> &Index) {
    return item<Dims, WithOffset>(Extent, Index);
  }

  /// Builds an nd_item from its global item, local item and group.
  template <int Dims>
  static nd_item<Dims> createNDItem(const item<Dims, true> &Global,
                                    const item<Dims, false> &Local,
                                    const group<Dims> &Group) {
    return nd_item<Dims>(Global, Local, Group);
  }

  /// Builds an h_item from a global and a local item.
  template <int Dims>
  static h_item<Dims> createHItem(const item<Dims, false> &Global,
                                  const item<Dims, false> &Local) {
    return h_item<Dims>(Global, Local);
  }

  /// Builds an h_item from a global item, a local item and a flexible range.
  template <int Dims>
  static h_item<Dims> createHItem(const item<Dims, false> &Global,
                                  const item<Dims, false> &Local,
                                  const range<Dims> &Flex) {
    return h_item<Dims>(Global, Local, Flex);
  }

  /// Overwrites the index stored inside Item (Item.MImpl.MIndex) with
  /// NextIndex.
  template <int Dims, bool WithOffset>
  static void updateItemIndex(sycl::item<Dims, WithOffset> &Item,
                              const id<Dims> &NextIndex) {
    Item.MImpl.MIndex = NextIndex;
  }

#ifdef __SYCL_DEVICE_ONLY__
  // Everything below is compiled only in the device pass. The helpers build
  // index-space objects from values returned by the __spirv::initBuiltIn*
  // functions.

  /// True for the dimensionalities accepted by the helpers below: 1, 2 or 3.
  template <int N>
  static inline constexpr bool is_valid_dimensions = (N > 0) && (N < 4);

  /// Returns the id obtained from initBuiltInGlobalInvocationId. The pointer
  /// argument only selects the overload.
  template <int Dims> static const id<Dims> getElement(id<Dims> *) {
    static_assert(is_valid_dimensions<Dims>, "invalid dimensions");
    return __spirv::initBuiltInGlobalInvocationId<Dims, id<Dims>>();
  }

  /// Returns a group assembled from the global size, work-group size, number
  /// of work-groups and work-group id built-ins. The pointer argument only
  /// selects the overload.
  template <int Dims> static const group<Dims> getElement(group<Dims> *) {
    static_assert(is_valid_dimensions<Dims>, "invalid dimensions");
    range<Dims> GlobalSize{__spirv::initBuiltInGlobalSize<Dims, range<Dims>>()};
    range<Dims> LocalSize{
        __spirv::initBuiltInWorkgroupSize<Dims, range<Dims>>()};
    range<Dims> GroupRange{
        __spirv::initBuiltInNumWorkgroups<Dims, range<Dims>>()};
    id<Dims> GroupId{__spirv::initBuiltInWorkgroupId<Dims, id<Dims>>()};
    return createGroup<Dims>(GlobalSize, LocalSize, GroupRange, GroupId);
  }

  /// Returns an item with offset, assembled from the global invocation id,
  /// global size and global offset built-ins.
  template <int Dims, bool WithOffset>
  static std::enable_if_t<WithOffset, const item<Dims, WithOffset>> getItem() {
    static_assert(is_valid_dimensions<Dims>, "invalid dimensions");
    id<Dims> GlobalId{__spirv::initBuiltInGlobalInvocationId<Dims, id<Dims>>()};
    range<Dims> GlobalSize{__spirv::initBuiltInGlobalSize<Dims, range<Dims>>()};
    id<Dims> GlobalOffset{__spirv::initBuiltInGlobalOffset<Dims, id<Dims>>()};
    return createItem<Dims, true>(GlobalSize, GlobalId, GlobalOffset);
  }

  /// Returns an item without offset, assembled from the global invocation id
  /// and global size built-ins.
  template <int Dims, bool WithOffset>
  static std::enable_if_t<!WithOffset, const item<Dims, WithOffset>> getItem() {
    static_assert(is_valid_dimensions<Dims>, "invalid dimensions");
    id<Dims> GlobalId{__spirv::initBuiltInGlobalInvocationId<Dims, id<Dims>>()};
    range<Dims> GlobalSize{__spirv::initBuiltInGlobalSize<Dims, range<Dims>>()};
    return createItem<Dims, false>(GlobalSize, GlobalId);
  }

  /// Returns an nd_item assembled from the built-ins: a group, a global item
  /// (with offset) and a local item (without offset). The pointer argument
  /// only selects the overload.
  template <int Dims> static const nd_item<Dims> getElement(nd_item<Dims> *) {
    static_assert(is_valid_dimensions<Dims>, "invalid dimensions");
    range<Dims> GlobalSize{__spirv::initBuiltInGlobalSize<Dims, range<Dims>>()};
    range<Dims> LocalSize{
        __spirv::initBuiltInWorkgroupSize<Dims, range<Dims>>()};
    range<Dims> GroupRange{
        __spirv::initBuiltInNumWorkgroups<Dims, range<Dims>>()};
    id<Dims> GroupId{__spirv::initBuiltInWorkgroupId<Dims, id<Dims>>()};
    id<Dims> GlobalId{__spirv::initBuiltInGlobalInvocationId<Dims, id<Dims>>()};
    id<Dims> LocalId{__spirv::initBuiltInLocalInvocationId<Dims, id<Dims>>()};
    id<Dims> GlobalOffset{__spirv::initBuiltInGlobalOffset<Dims, id<Dims>>()};
    group<Dims> Group =
        createGroup<Dims>(GlobalSize, LocalSize, GroupRange, GroupId);
    item<Dims, true> GlobalItem =
        createItem<Dims, true>(GlobalSize, GlobalId, GlobalOffset);
    item<Dims, false> LocalItem = createItem<Dims, false>(LocalSize, LocalId);
    return createNDItem<Dims>(GlobalItem, LocalItem, Group);
  }

  /// item overload of getElement: forwards to getItem with the same template
  /// arguments. The pointer argument only selects the overload.
  template <int Dims, bool WithOffset>
  static auto getElement(item<Dims, WithOffset> *)
      -> decltype(getItem<Dims, WithOffset>()) {
    return getItem<Dims, WithOffset>();
  }

  /// Convenience wrapper: the nd_item overload of getElement without the
  /// caller having to supply the tag pointer.
  template <int Dims>
  static auto getNDItem() -> decltype(getElement(declptr<nd_item<Dims>>())) {
    return getElement(declptr<nd_item<Dims>>());
  }

#endif // __SYCL_DEVICE_ONLY__
};

} // namespace detail
} // namespace _V1
} // namespace sycl
2 changes: 1 addition & 1 deletion sycl/include/sycl/detail/builtins/builtins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
#pragma once

#include <sycl/detail/fwd/multi_ptr.hpp>
#include <sycl/detail/helpers.hpp>
#include <sycl/detail/loop.hpp>
#include <sycl/detail/type_traits.hpp>
#include <sycl/detail/type_traits/vec_marray_traits.hpp>
#include <sycl/detail/vector_convert.hpp>
Expand Down
18 changes: 9 additions & 9 deletions sycl/include/sycl/detail/cg_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,20 @@
#pragma once

#include <sycl/detail/array.hpp> // for array
#include <sycl/detail/common.hpp> // for InitializedVal, NDLoop
#include <sycl/detail/helpers.hpp> // for Builder
#include <sycl/detail/host_profiling_info.hpp> // for HostProfilingInfo
#include <sycl/detail/item_base.hpp> // for id
#include <sycl/detail/kernel_desc.hpp> // for kernel_param_kind_t
#include <sycl/detail/nd_loop.hpp> // for InitializedVal, NDLoop
#include <sycl/exception.hpp>
#include <sycl/group.hpp> // for group
#include <sycl/h_item.hpp> // for h_item
#include <sycl/id.hpp> // for id
#include <sycl/item.hpp> // for item
#include <sycl/kernel_handler.hpp> // for kernel_handler
#include <sycl/nd_item.hpp> // for nd_item
#include <sycl/nd_range.hpp> // for nd_range
#include <sycl/range.hpp> // for range, operator*
#include <sycl/group.hpp> // for group
#include <sycl/h_item.hpp> // for h_item
#include <sycl/id.hpp> // for id
#include <sycl/item.hpp> // for item
#include <sycl/kernel_handler.hpp> // for kernel_handler
#include <sycl/nd_item.hpp> // for nd_item
#include <sycl/nd_range.hpp> // for nd_range
#include <sycl/range.hpp> // for range, operator*

#include <functional> // for function
#include <stddef.h> // for size_t
Expand Down
Loading
Loading