Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 53 additions & 22 deletions src/ATen/native/xpu/sycl/Reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,48 @@ using at::detail::Array;

namespace detail {

template <class arg_t, class item_t, class CombineFunc, int out_vec_sz = 1>
template <class T, class = void>
struct get_native_sycl_op {
using type = void;
};
template <class T>
struct get_native_sycl_op<T, std::void_t<typename T::native_sycl_op>> {
using type = typename T::native_sycl_op;
};
template <class T>
using native_sycl_op_t = typename get_native_sycl_op<T>::type;

template <class arg_t, class CombineFunc, class NativeOp, int out_vec_sz>
at::detail::Array<arg_t, out_vec_sz> tree_reduce(
sycl::sub_group sg,
at::detail::Array<arg_t, out_vec_sz> value,
CombineFunc combine) {
if constexpr (
!std::is_same<NativeOp, void>::value && std::is_floating_point_v<arg_t>) {
#pragma unroll(out_vec_sz)
for (int i = 0; i < out_vec_sz; ++i) {
value[i] = sycl::reduce_over_group(sg, value[i], NativeOp{});
}
} else {
int sg_size = sg.get_local_range()[0];
for (int offset = 1; offset < sg_size; offset <<= 1) {
#pragma unroll(out_vec_sz)
for (int i = 0; i < out_vec_sz; ++i) {
arg_t other = sycl::shift_group_left(sg, value[i], offset);
value[i] = combine(value[i], other);
}
}
}

return value;
}

template <
class arg_t,
class item_t,
class CombineFunc,
class NativeOp = void,
int out_vec_sz = 1>
inline at::detail::Array<arg_t, out_vec_sz> group_reduce(
item_t item,
int wg_size,
Expand All @@ -59,13 +100,9 @@ inline at::detail::Array<arg_t, out_vec_sz> group_reduce(
SYCL_KERNEL_ASSERT(
wg_size % sg_size == 0 && "unsupported workgroup size for group reduce");

for (int offset = 1; offset < sg_size; offset <<= 1) {
#pragma unroll(out_vec_sz)
for (int i = 0; i < out_vec_sz; ++i) {
arg_t other = sycl::shift_group_left(sg, value[i], offset);
value[i] = combine(value[i], other);
}
}
// tree reduce in subgroup
value =
tree_reduce<arg_t, CombineFunc, NativeOp, out_vec_sz>(sg, value, combine);

if (sg_lid == 0) {
shared_[sg_gid] = value;
Expand All @@ -80,13 +117,8 @@ inline at::detail::Array<arg_t, out_vec_sz> group_reduce(

if (sg_gid == 0 && sg_lid < sg_range) {
value = shared_[sg_lid];
for (int offset = 1; offset < sg_range; offset <<= 1) {
#pragma unroll(out_vec_sz)
for (int i = 0; i < out_vec_sz; ++i) {
arg_t other = sycl::shift_group_left(sg, value[i], offset);
value[i] = combine(value[i], other);
}
}
value = tree_reduce<arg_t, CombineFunc, NativeOp, out_vec_sz>(
sg, value, combine);
}
} else {
// work item tree reduce
Expand Down Expand Up @@ -141,13 +173,7 @@ inline at::detail::Array<arg_t, out_vec_sz> group_x_reduce(
}

// sub-group reduction
for (int offset = 1; offset < dim_x; offset <<= 1) {
#pragma unroll(out_vec_sz)
for (int i = 0; i < out_vec_sz; ++i) {
arg_t other = sycl::shift_group_left(sg, value[i], offset);
value[i] = combine(value[i], other);
}
}
value = tree_reduce<arg_t, CombineFunc, void, out_vec_sz>(sg, value, combine);
return value;
}

Expand Down Expand Up @@ -445,6 +471,8 @@ template <typename out_scalar_t, typename func_t>
struct func_wrapper_t {
using arg_t = typename binary_function_traits<func_t>::arg1_t;
using scalar_t = typename binary_function_traits<func_t>::arg2_t;
// Propagate native_sycl_op from func_t
using native_sycl_op = native_sycl_op_t<func_t>;

func_t combine;
static inline out_scalar_t project(arg_t arg) {
Expand Down Expand Up @@ -570,11 +598,14 @@ struct ReduceOp {
return ops.combine(value, other);
};

using native_op_t = native_sycl_op_t<ops_t>;

if (config.should_group_x_reduce() && config.should_group_y_reduce()) {
value = group_reduce<
arg_t,
decltype(pos),
decltype(combine),
native_op_t,
output_vec_size>(pos, config.num_items, shared, value, combine);
} else {
if (config.should_group_y_reduce()) {
Expand Down
1 change: 1 addition & 0 deletions src/ATen/native/xpu/sycl/ReduceSumProdKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ static void reduce_dispatch(TensorIterator& iter, GeneralDispatcher op) {

template <typename acc_t>
struct SumFunctor {
using native_sycl_op = sycl::plus<acc_t>;
inline acc_t operator()(acc_t a, acc_t b) const {
return a + b;
}
Expand Down
Loading