Skip to content

Commit c6559fb

Browse files
committed
cub: test DeviceTransform aligned_size_t vectorized store (same-width/narrowing/widening)
1 parent da1fc05 commit c6559fb

1 file changed

Lines changed: 103 additions & 0 deletions

File tree

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
3+
4+
#include "insert_nested_NVTX_range_guard.h"
5+
6+
#include <cub/device/device_transform.cuh>
7+
8+
#include <cuda/__memory/aligned_size.h>
9+
10+
#include <algorithm>
11+
12+
#include "catch2_test_launch_helper.h"
13+
#include <c2h/catch2_test_helper.h>
14+
#include <c2h/test_util_vec.h>
15+
16+
// %PARAM% TEST_LAUNCH lid 0:1:2
17+
18+
DECLARE_LAUNCH_WRAPPER(cub::DeviceTransform::Transform, transform_many);
19+
20+
#define ALIGNED_ITEM_COUNTS 0, 16, 128, 4096, 4112, 65536, 99'984
21+
22+
template <typename Out>
23+
struct cast_to
24+
{
25+
template <typename T>
26+
__host__ __device__ Out operator()(T v) const
27+
{
28+
return static_cast<Out>(v);
29+
}
30+
};
31+
32+
C2H_TEST("DeviceTransform::Transform aligned_size_t<16> same-width",
33+
"[device][transform]",
34+
c2h::type_list<std::uint8_t, std::uint16_t, std::uint32_t, std::uint64_t, float, double, uchar3>)
35+
{
36+
using type = c2h::get<0, TestType>;
37+
using offset_t = cuda::std::int64_t;
38+
const offset_t num_items = GENERATE(ALIGNED_ITEM_COUNTS);
39+
CAPTURE(c2h::type_name<type>(), num_items);
40+
41+
c2h::device_vector<type> a(num_items, thrust::no_init);
42+
c2h::device_vector<type> b(num_items, thrust::no_init);
43+
c2h::gen(C2H_SEED(1), a);
44+
c2h::gen(C2H_SEED(1), b);
45+
46+
c2h::device_vector<type> result(num_items, thrust::no_init);
47+
transform_many(cuda::std::make_tuple(a.begin(), b.begin()),
48+
result.begin(),
49+
cuda::aligned_size_t<16>(num_items),
50+
cuda::std::plus<type>{});
51+
52+
c2h::host_vector<type> a_h = a;
53+
c2h::host_vector<type> b_h = b;
54+
c2h::host_vector<type> reference_h(num_items, thrust::no_init);
55+
std::transform(a_h.begin(), a_h.end(), b_h.begin(), reference_h.begin(), cuda::std::plus<type>{});
56+
REQUIRE(reference_h == result);
57+
}
58+
59+
C2H_TEST("DeviceTransform::Transform aligned_size_t<16> narrowing to uint8",
60+
"[device][transform]",
61+
c2h::type_list<std::uint16_t, std::uint32_t, std::uint64_t>)
62+
{
63+
using in_t = c2h::get<0, TestType>;
64+
using out_t = std::uint8_t;
65+
using offset_t = cuda::std::int64_t;
66+
const offset_t num_items = GENERATE(ALIGNED_ITEM_COUNTS);
67+
CAPTURE(c2h::type_name<in_t>(), num_items);
68+
69+
c2h::device_vector<in_t> in(num_items, thrust::no_init);
70+
c2h::gen(C2H_SEED(1), in);
71+
72+
c2h::device_vector<out_t> result(num_items, thrust::no_init);
73+
transform_many(
74+
cuda::std::make_tuple(in.begin()), result.begin(), cuda::aligned_size_t<16>(num_items), cast_to<out_t>{});
75+
76+
c2h::host_vector<in_t> in_h = in;
77+
c2h::host_vector<out_t> reference_h(num_items, thrust::no_init);
78+
std::transform(in_h.begin(), in_h.end(), reference_h.begin(), cast_to<out_t>{});
79+
REQUIRE(reference_h == result);
80+
}
81+
82+
C2H_TEST("DeviceTransform::Transform aligned_size_t<16> widening from uint8",
83+
"[device][transform]",
84+
c2h::type_list<std::uint16_t, std::uint32_t, std::uint64_t>)
85+
{
86+
using in_t = std::uint8_t;
87+
using out_t = c2h::get<0, TestType>;
88+
using offset_t = cuda::std::int64_t;
89+
const offset_t num_items = GENERATE(ALIGNED_ITEM_COUNTS);
90+
CAPTURE(c2h::type_name<out_t>(), num_items);
91+
92+
c2h::device_vector<in_t> in(num_items, thrust::no_init);
93+
c2h::gen(C2H_SEED(1), in);
94+
95+
c2h::device_vector<out_t> result(num_items, thrust::no_init);
96+
transform_many(
97+
cuda::std::make_tuple(in.begin()), result.begin(), cuda::aligned_size_t<16>(num_items), cast_to<out_t>{});
98+
99+
c2h::host_vector<in_t> in_h = in;
100+
c2h::host_vector<out_t> reference_h(num_items, thrust::no_init);
101+
std::transform(in_h.begin(), in_h.end(), reference_h.begin(), cast_to<out_t>{});
102+
REQUIRE(reference_h == result);
103+
}

0 commit comments

Comments
 (0)