Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 198 additions & 10 deletions cub/cub/device/device_transform.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ public:
typename... RandomAccessIteratorsOut,
typename NumItemsT,
typename TransformOp,
typename Env = ::cuda::std::execution::env<>>
typename Env = ::cuda::std::execution::env<>,
::cuda::std::enable_if_t<!::cuda::std::is_convertible_v<Env, cudaStream_t>, int> = 0>
CUB_RUNTIME_FUNCTION static cudaError_t Transform(
::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
::cuda::std::tuple<RandomAccessIteratorsOut...> outputs,
Expand All @@ -201,6 +202,29 @@ public:
}

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
// The environment-taking overload above is only enabled when `Env` is not convertible to `cudaStream_t`, and the
// `cudaStream_t` overload below forwards to it by passing a `::cuda::stream_ref`. This check makes sure that
// `::cuda::stream_ref` itself is not convertible to `cudaStream_t`, so the two overloads are never ambiguous.
static_assert(!::cuda::std::is_convertible_v<::cuda::stream_ref, cudaStream_t>);

// Overload taking a plain `cudaStream_t`. The environment-taking overload is disabled for arguments convertible to
// `cudaStream_t`, so such arguments (including types that are convertible to `cudaStream_t` but not copyable) are
// handled here and forwarded as a `::cuda::stream_ref`.
template <typename... RandomAccessIteratorsIn,
          typename... RandomAccessIteratorsOut,
          typename NumItemsT,
          typename TransformOp>
CUB_RUNTIME_FUNCTION static cudaError_t Transform(
  ::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
  ::cuda::std::tuple<RandomAccessIteratorsOut...> outputs,
  NumItemsT num_items,
  TransformOp transform_op,
  cudaStream_t stream)
{
  // Wrapping the raw handle selects the environment-taking overload.
  ::cuda::stream_ref stream_env{stream};
  return Transform(
    ::cuda::std::move(inputs), ::cuda::std::move(outputs), num_items, ::cuda::std::move(transform_op), stream_env);
}

// Overload with additional parameters to specify temporary storage. Provided for compatibility with other CUB APIs.
template <typename... RandomAccessIteratorsIn,
typename... RandomAccessIteratorsOut,
Expand Down Expand Up @@ -261,7 +285,8 @@ public:
typename RandomAccessIteratorOut,
typename NumItemsT,
typename TransformOp,
typename Env = ::cuda::std::execution::env<>>
typename Env = ::cuda::std::execution::env<>,
::cuda::std::enable_if_t<!::cuda::std::is_convertible_v<Env, cudaStream_t>, int> = 0>
CUB_RUNTIME_FUNCTION static cudaError_t Transform(
::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
RandomAccessIteratorOut output,
Expand All @@ -280,6 +305,23 @@ public:
}

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
// Overload taking a plain `cudaStream_t` (several inputs, one output). The environment-taking overload is disabled
// for arguments convertible to `cudaStream_t`, so such arguments (including types that are convertible to
// `cudaStream_t` but not copyable) are handled here and forwarded as a `::cuda::stream_ref`.
template <typename... RandomAccessIteratorsIn,
          typename RandomAccessIteratorOut,
          typename NumItemsT,
          typename TransformOp>
CUB_RUNTIME_FUNCTION static cudaError_t Transform(
  ::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
  RandomAccessIteratorOut output,
  NumItemsT num_items,
  TransformOp transform_op,
  cudaStream_t stream)
{
  // Wrapping the raw handle selects the environment-taking overload.
  ::cuda::stream_ref stream_env{stream};
  return Transform(
    ::cuda::std::move(inputs), ::cuda::std::move(output), num_items, ::cuda::std::move(transform_op), stream_env);
}

// Overload with additional parameters to specify temporary storage. Provided for compatibility with other CUB APIs.
template <typename... RandomAccessIteratorsIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
CUB_RUNTIME_FUNCTION static cudaError_t Transform(
Expand Down Expand Up @@ -325,7 +367,8 @@ public:
typename RandomAccessIteratorOut,
typename NumItemsT,
typename TransformOp,
typename Env = ::cuda::std::execution::env<>>
typename Env = ::cuda::std::execution::env<>,
::cuda::std::enable_if_t<!::cuda::std::is_convertible_v<Env, cudaStream_t>, int> = 0>
CUB_RUNTIME_FUNCTION static cudaError_t Transform(
RandomAccessIteratorIn input,
RandomAccessIteratorOut output,
Expand All @@ -342,6 +385,23 @@ public:
}

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
// Overload taking a plain `cudaStream_t` (one input, one output). The environment-taking overload is disabled for
// arguments convertible to `cudaStream_t`, so such arguments (including types that are convertible to
// `cudaStream_t` but not copyable) are handled here. The single input is packed into a one-element tuple and the
// stream is forwarded as a `::cuda::stream_ref`.
template <typename RandomAccessIteratorIn,
          typename RandomAccessIteratorOut,
          typename NumItemsT,
          typename TransformOp>
CUB_RUNTIME_FUNCTION static cudaError_t Transform(
  RandomAccessIteratorIn input,
  RandomAccessIteratorOut output,
  NumItemsT num_items,
  TransformOp transform_op,
  cudaStream_t stream)
{
  // Wrapping the raw handle selects an environment-taking overload.
  ::cuda::stream_ref stream_env{stream};
  return Transform(
    ::cuda::std::make_tuple(::cuda::std::move(input)),
    ::cuda::std::move(output),
    num_items,
    ::cuda::std::move(transform_op),
    stream_env);
}

// Overload with additional parameters to specify temporary storage. Provided for compatibility with other CUB APIs.
template <typename RandomAccessIteratorIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
CUB_RUNTIME_FUNCTION static cudaError_t Transform(
Expand Down Expand Up @@ -387,7 +447,8 @@ public:
template <typename RandomAccessIteratorOut,
typename NumItemsT,
typename Generator,
typename Env = ::cuda::std::execution::env<>>
typename Env = ::cuda::std::execution::env<>,
::cuda::std::enable_if_t<!::cuda::std::is_convertible_v<Env, cudaStream_t>, int> = 0>
CUB_RUNTIME_FUNCTION static cudaError_t
Generate(RandomAccessIteratorOut output, NumItemsT num_items, Generator generator, Env env = {})
{
Expand All @@ -408,6 +469,14 @@ public:
}

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
// Overload taking a plain `cudaStream_t`. The environment-taking overload is disabled for arguments convertible to
// `cudaStream_t`, so such arguments (including types that are convertible to `cudaStream_t` but not copyable) are
// handled here and forwarded as a `::cuda::stream_ref`.
template <typename RandomAccessIteratorOut, typename NumItemsT, typename Generator>
CUB_RUNTIME_FUNCTION static cudaError_t
Generate(RandomAccessIteratorOut output, NumItemsT num_items, Generator generator, cudaStream_t stream)
{
  // Wrapping the raw handle selects the environment-taking overload.
  ::cuda::stream_ref stream_env{stream};
  return Generate(::cuda::std::move(output), num_items, ::cuda::std::move(generator), stream_env);
}

// Overload with additional parameters to specify temporary storage. Provided for compatibility with other CUB APIs.
template <typename RandomAccessIteratorOut, typename NumItemsT, typename Generator>
CUB_RUNTIME_FUNCTION static cudaError_t Generate(
Expand Down Expand Up @@ -446,7 +515,8 @@ public:
template <typename RandomAccessIteratorOut,
typename NumItemsT,
typename Value,
typename Env = ::cuda::std::execution::env<>>
typename Env = ::cuda::std::execution::env<>,
::cuda::std::enable_if_t<!::cuda::std::is_convertible_v<Env, cudaStream_t>, int> = 0>
CUB_RUNTIME_FUNCTION static cudaError_t
Fill(RandomAccessIteratorOut output, NumItemsT num_items, Value value, Env env = {})
{
Expand All @@ -464,6 +534,14 @@ public:
}

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
// Overload taking a plain `cudaStream_t`. The environment-taking overload is disabled for arguments convertible to
// `cudaStream_t`, so such arguments (including types that are convertible to `cudaStream_t` but not copyable) are
// handled here and forwarded as a `::cuda::stream_ref`.
template <typename RandomAccessIteratorOut, typename NumItemsT, typename Value>
CUB_RUNTIME_FUNCTION static cudaError_t
Fill(RandomAccessIteratorOut output, NumItemsT num_items, Value value, cudaStream_t stream)
{
  // Wrapping the raw handle selects the environment-taking overload.
  ::cuda::stream_ref stream_env{stream};
  return Fill(::cuda::std::move(output), num_items, ::cuda::std::move(value), stream_env);
}

// Overload with additional parameters to specify temporary storage. Provided for compatibility with other CUB APIs.
template <typename RandomAccessIteratorOut, typename NumItemsT, typename Value>
CUB_RUNTIME_FUNCTION static cudaError_t
Expand Down Expand Up @@ -525,7 +603,8 @@ public:
typename NumItemsT,
typename Predicate,
typename TransformOp,
typename Env = ::cuda::std::execution::env<>>
typename Env = ::cuda::std::execution::env<>,
::cuda::std::enable_if_t<!::cuda::std::is_convertible_v<Env, cudaStream_t>, int> = 0>
CUB_RUNTIME_FUNCTION static cudaError_t TransformIf(
::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
RandomAccessIteratorOut output,
Expand All @@ -545,6 +624,29 @@ public:
}

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
// Overload taking a plain `cudaStream_t` (several inputs). The environment-taking overload is disabled for arguments
// convertible to `cudaStream_t`, so such arguments (including types that are convertible to `cudaStream_t` but not
// copyable) are handled here and forwarded as a `::cuda::stream_ref`.
template <typename... RandomAccessIteratorsIn,
          typename RandomAccessIteratorOut,
          typename NumItemsT,
          typename Predicate,
          typename TransformOp>
CUB_RUNTIME_FUNCTION static cudaError_t TransformIf(
  ::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
  RandomAccessIteratorOut output,
  NumItemsT num_items,
  Predicate predicate,
  TransformOp transform_op,
  cudaStream_t stream)
{
  // Wrapping the raw handle selects the environment-taking overload.
  ::cuda::stream_ref stream_env{stream};
  return TransformIf(
    ::cuda::std::move(inputs),
    ::cuda::std::move(output),
    num_items,
    ::cuda::std::move(predicate),
    ::cuda::std::move(transform_op),
    stream_env);
}

// Overload with additional parameters to specify temporary storage. Provided for compatibility with other CUB APIs.
template <typename... RandomAccessIteratorsIn,
typename RandomAccessIteratorOut,
Expand Down Expand Up @@ -616,7 +718,8 @@ public:
typename NumItemsT,
typename Predicate,
typename TransformOp,
typename Env = ::cuda::std::execution::env<>>
typename Env = ::cuda::std::execution::env<>,
::cuda::std::enable_if_t<!::cuda::std::is_convertible_v<Env, cudaStream_t>, int> = 0>
CUB_RUNTIME_FUNCTION static cudaError_t TransformIf(
RandomAccessIteratorIn input,
RandomAccessIteratorOut output,
Expand All @@ -635,6 +738,29 @@ public:
}

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
// Overload taking a plain `cudaStream_t` (one input). The environment-taking overload is disabled for arguments
// convertible to `cudaStream_t`, so such arguments (including types that are convertible to `cudaStream_t` but not
// copyable) are handled here. The single input is packed into a one-element tuple and the stream is forwarded as a
// `::cuda::stream_ref`.
template <typename RandomAccessIteratorIn,
          typename RandomAccessIteratorOut,
          typename NumItemsT,
          typename Predicate,
          typename TransformOp>
CUB_RUNTIME_FUNCTION static cudaError_t TransformIf(
  RandomAccessIteratorIn input,
  RandomAccessIteratorOut output,
  NumItemsT num_items,
  Predicate predicate,
  TransformOp transform_op,
  cudaStream_t stream)
{
  // Wrapping the raw handle selects an environment-taking overload.
  ::cuda::stream_ref stream_env{stream};
  return TransformIf(
    ::cuda::std::make_tuple(::cuda::std::move(input)),
    ::cuda::std::move(output),
    num_items,
    ::cuda::std::move(predicate),
    ::cuda::std::move(transform_op),
    stream_env);
}

// Overload with additional parameters to specify temporary storage. Provided for compatibility with other CUB APIs.
template <typename RandomAccessIteratorIn,
typename RandomAccessIteratorOut,
Expand Down Expand Up @@ -702,7 +828,8 @@ public:
typename RandomAccessIteratorOut,
typename NumItemsT,
typename TransformOp,
typename Env = ::cuda::std::execution::env<>>
typename Env = ::cuda::std::execution::env<>,
::cuda::std::enable_if_t<!::cuda::std::is_convertible_v<Env, cudaStream_t>, int> = 0>
CUB_RUNTIME_FUNCTION static cudaError_t TransformStableArgumentAddresses(
::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
RandomAccessIteratorOut output,
Expand All @@ -721,6 +848,23 @@ public:
}

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
// Overload taking a plain `cudaStream_t` (several inputs). The environment-taking overload is disabled for arguments
// convertible to `cudaStream_t`, so such arguments (including types that are convertible to `cudaStream_t` but not
// copyable) are handled here and forwarded as a `::cuda::stream_ref`.
template <typename... RandomAccessIteratorsIn,
          typename RandomAccessIteratorOut,
          typename NumItemsT,
          typename TransformOp>
CUB_RUNTIME_FUNCTION static cudaError_t TransformStableArgumentAddresses(
  ::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
  RandomAccessIteratorOut output,
  NumItemsT num_items,
  TransformOp transform_op,
  cudaStream_t stream)
{
  // Wrapping the raw handle selects the environment-taking overload.
  ::cuda::stream_ref stream_env{stream};
  return TransformStableArgumentAddresses(
    ::cuda::std::move(inputs), ::cuda::std::move(output), num_items, ::cuda::std::move(transform_op), stream_env);
}

template <typename... RandomAccessIteratorsIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
CUB_RUNTIME_FUNCTION static cudaError_t TransformStableArgumentAddresses(
void* d_temp_storage,
Expand Down Expand Up @@ -765,7 +909,8 @@ public:
typename RandomAccessIteratorOut,
typename NumItemsT,
typename TransformOp,
typename Env = ::cuda::std::execution::env<>>
typename Env = ::cuda::std::execution::env<>,
::cuda::std::enable_if_t<!::cuda::std::is_convertible_v<Env, cudaStream_t>, int> = 0>
CUB_RUNTIME_FUNCTION static cudaError_t TransformStableArgumentAddresses(
RandomAccessIteratorIn input,
RandomAccessIteratorOut output,
Expand All @@ -782,6 +927,23 @@ public:
}

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
// Overload taking a plain `cudaStream_t` (one input). The environment-taking overload is disabled for arguments
// convertible to `cudaStream_t`, so such arguments (including types that are convertible to `cudaStream_t` but not
// copyable) are handled here. The single input is packed into a one-element tuple and the stream is forwarded as a
// `::cuda::stream_ref`.
template <typename RandomAccessIteratorIn,
          typename RandomAccessIteratorOut,
          typename NumItemsT,
          typename TransformOp>
CUB_RUNTIME_FUNCTION static cudaError_t TransformStableArgumentAddresses(
  RandomAccessIteratorIn input,
  RandomAccessIteratorOut output,
  NumItemsT num_items,
  TransformOp transform_op,
  cudaStream_t stream)
{
  // Wrapping the raw handle selects an environment-taking overload.
  ::cuda::stream_ref stream_env{stream};
  return TransformStableArgumentAddresses(
    ::cuda::std::make_tuple(::cuda::std::move(input)),
    ::cuda::std::move(output),
    num_items,
    ::cuda::std::move(transform_op),
    stream_env);
}

template <typename RandomAccessIteratorIn, typename RandomAccessIteratorOut, typename NumItemsT, typename TransformOp>
CUB_RUNTIME_FUNCTION static cudaError_t TransformStableArgumentAddresses(
void* d_temp_storage,
Expand Down Expand Up @@ -813,7 +975,8 @@ public:
typename NumItemsT,
typename Predicate,
typename TransformOp,
typename Env = ::cuda::std::execution::env<>>
typename Env = ::cuda::std::execution::env<>,
::cuda::std::enable_if_t<!::cuda::std::is_convertible_v<Env, cudaStream_t>, int> = 0>
CUB_RUNTIME_FUNCTION static cudaError_t __transform_if_stable_argument_addresses(
::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
RandomAccessIteratorOut output,
Expand All @@ -831,6 +994,31 @@ public:
::cuda::std::move(transform_op),
::cuda::std::move(env));
}

#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
// Overload taking a plain `cudaStream_t`. The environment-taking overload is disabled for arguments convertible to
// `cudaStream_t`, so such arguments (including types that are convertible to `cudaStream_t` but not copyable) are
// handled here and forwarded as a `::cuda::stream_ref`.
template <typename... RandomAccessIteratorsIn,
          typename RandomAccessIteratorOut,
          typename NumItemsT,
          typename Predicate,
          typename TransformOp>
CUB_RUNTIME_FUNCTION static cudaError_t __transform_if_stable_argument_addresses(
  ::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
  RandomAccessIteratorOut output,
  NumItemsT num_items,
  Predicate predicate,
  TransformOp transform_op,
  cudaStream_t stream)
{
  // Wrapping the raw handle selects the environment-taking overload.
  ::cuda::stream_ref stream_env{stream};
  return __transform_if_stable_argument_addresses(
    ::cuda::std::move(inputs),
    ::cuda::std::move(output),
    num_items,
    ::cuda::std::move(predicate),
    ::cuda::std::move(transform_op),
    stream_env);
}
#endif // _CCCL_DOXYGEN_INVOKED
};

CUB_NAMESPACE_END
22 changes: 22 additions & 0 deletions cub/test/catch2_test_device_transform_env.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,24 @@ struct stream_convertible
return stream;
}
};
struct stream_convertible_non_copyable
{
cudaStream_t stream;

stream_convertible_non_copyable(cudaStream_t stream)
: stream(stream)
{}

stream_convertible_non_copyable(const stream_convertible_non_copyable&) = delete;
auto operator=(const stream_convertible_non_copyable&) -> stream_convertible_non_copyable& = delete;
stream_convertible_non_copyable(stream_convertible_non_copyable&&) = default;
auto operator=(stream_convertible_non_copyable&&) -> stream_convertible_non_copyable& = default;

operator cudaStream_t() const noexcept
{
return stream;
}
};

struct with_stream_method
{
Expand Down Expand Up @@ -64,6 +82,10 @@ void check_graph_nodes_with_different_streams(F call_cub_api)
{
call_cub_api(stream_convertible{stream.get()});
}
SECTION("stream_convertible_non_copyable")
{
call_cub_api(stream_convertible_non_copyable{stream.get()});
}
SECTION("with_stream_method")
{
call_cub_api(with_stream_method{stream.get()});
Expand Down
Loading