Skip to content

Commit f07ec91

Browse files
committed
Add env DeviceMerge
1 parent e4e153b commit f07ec91

File tree

3 files changed

+81
-81
lines changed

3 files changed

+81
-81
lines changed

cub/cub/device/device_merge.cuh

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -157,8 +157,8 @@ struct DeviceMerge
157157
template <typename KeyIteratorIn1,
158158
typename KeyIteratorIn2,
159159
typename KeyIteratorOut,
160-
typename CompareOp = ::cuda::std::less<>,
161-
typename EnvT = ::cuda::std::execution::env<>,
160+
typename CompareOp = ::cuda::std::less<>,
161+
typename EnvT = ::cuda::std::execution::env<>,
162162
::cuda::std::enable_if_t<!::cuda::std::is_same_v<KeyIteratorIn1, void*>
163163
&& !::cuda::std::is_same_v<KeyIteratorIn1, ::cuda::std::nullptr_t>,
164164
int> = 0>
@@ -351,8 +351,8 @@ struct DeviceMerge
351351
typename ValueIteratorIn2,
352352
typename KeyIteratorOut,
353353
typename ValueIteratorOut,
354-
typename CompareOp = ::cuda::std::less<>,
355-
typename EnvT = ::cuda::std::execution::env<>,
354+
typename CompareOp = ::cuda::std::less<>,
355+
typename EnvT = ::cuda::std::execution::env<>,
356356
::cuda::std::enable_if_t<!::cuda::std::is_same_v<KeyIteratorIn1, void*>
357357
&& !::cuda::std::is_same_v<KeyIteratorIn1, ::cuda::std::nullptr_t>,
358358
int> = 0>

cub/test/catch2_test_device_merge_env.cu

Lines changed: 75 additions & 75 deletions
Original file line number | Diff line number | Diff line change
@@ -32,13 +32,10 @@ TEST_CASE("DeviceMerge::MergeKeys works with default environment", "[merge][devi
3232
auto keys2 = c2h::device_vector<int>{0, 3, 3, 4};
3333
auto result = c2h::device_vector<int>(7);
3434

35-
REQUIRE(cudaSuccess
36-
== cub::DeviceMerge::MergeKeys(
37-
keys1.begin(),
38-
static_cast<int>(keys1.size()),
39-
keys2.begin(),
40-
static_cast<int>(keys2.size()),
41-
result.begin()));
35+
REQUIRE(
36+
cudaSuccess
37+
== cub::DeviceMerge::MergeKeys(
38+
keys1.begin(), static_cast<int>(keys1.size()), keys2.begin(), static_cast<int>(keys2.size()), result.begin()));
4239

4340
c2h::device_vector<int> expected{0, 0, 2, 3, 3, 4, 5};
4441
REQUIRE(result == expected);
@@ -54,16 +51,17 @@ TEST_CASE("DeviceMerge::MergePairs works with default environment", "[merge][dev
5451
auto result_keys = c2h::device_vector<int>(7);
5552
auto result_values = c2h::device_vector<char>(7);
5653

57-
REQUIRE(cudaSuccess
58-
== cub::DeviceMerge::MergePairs(
59-
keys1.begin(),
60-
values1.begin(),
61-
static_cast<int>(keys1.size()),
62-
keys2.begin(),
63-
values2.begin(),
64-
static_cast<int>(keys2.size()),
65-
result_keys.begin(),
66-
result_values.begin()));
54+
REQUIRE(
55+
cudaSuccess
56+
== cub::DeviceMerge::MergePairs(
57+
keys1.begin(),
58+
values1.begin(),
59+
static_cast<int>(keys1.size()),
60+
keys2.begin(),
61+
values2.begin(),
62+
static_cast<int>(keys2.size()),
63+
result_keys.begin(),
64+
result_values.begin()));
6765

6866
c2h::device_vector<int> expected_keys{0, 0, 2, 3, 3, 4, 5};
6967
c2h::device_vector<char> expected_values{'a', 'A', 'b', 'B', 'C', 'D', 'c'};
@@ -80,26 +78,26 @@ C2H_TEST("DeviceMerge::MergeKeys uses environment", "[merge][device]")
8078
auto result = c2h::device_vector<int>(7);
8179

8280
size_t expected_bytes_allocated{};
83-
REQUIRE(cudaSuccess
84-
== cub::DeviceMerge::MergeKeys(
85-
nullptr,
86-
expected_bytes_allocated,
87-
keys1.begin(),
88-
static_cast<int>(keys1.size()),
89-
keys2.begin(),
90-
static_cast<int>(keys2.size()),
91-
result.begin()));
81+
REQUIRE(
82+
cudaSuccess
83+
== cub::DeviceMerge::MergeKeys(
84+
nullptr,
85+
expected_bytes_allocated,
86+
keys1.begin(),
87+
static_cast<int>(keys1.size()),
88+
keys2.begin(),
89+
static_cast<int>(keys2.size()),
90+
result.begin()));
9291

9392
auto env = stdexec::env{expected_allocation_size(expected_bytes_allocated)};
9493

95-
merge_keys(
96-
keys1.begin(),
97-
static_cast<int>(keys1.size()),
98-
keys2.begin(),
99-
static_cast<int>(keys2.size()),
100-
result.begin(),
101-
cuda::std::less<>{},
102-
env);
94+
merge_keys(keys1.begin(),
95+
static_cast<int>(keys1.size()),
96+
keys2.begin(),
97+
static_cast<int>(keys2.size()),
98+
result.begin(),
99+
cuda::std::less<>{},
100+
env);
103101

104102
c2h::device_vector<int> expected{0, 0, 2, 3, 3, 4, 5};
105103
REQUIRE(result == expected);
@@ -115,27 +113,27 @@ TEST_CASE("DeviceMerge::MergeKeys uses custom stream", "[merge][device]")
115113
REQUIRE(cudaSuccess == cudaStreamCreate(&custom_stream));
116114

117115
size_t expected_bytes_allocated{};
118-
REQUIRE(cudaSuccess
119-
== cub::DeviceMerge::MergeKeys(
120-
nullptr,
121-
expected_bytes_allocated,
122-
keys1.begin(),
123-
static_cast<int>(keys1.size()),
124-
keys2.begin(),
125-
static_cast<int>(keys2.size()),
126-
result.begin()));
116+
REQUIRE(
117+
cudaSuccess
118+
== cub::DeviceMerge::MergeKeys(
119+
nullptr,
120+
expected_bytes_allocated,
121+
keys1.begin(),
122+
static_cast<int>(keys1.size()),
123+
keys2.begin(),
124+
static_cast<int>(keys2.size()),
125+
result.begin()));
127126

128127
auto stream_prop = stdexec::prop{cuda::get_stream_t{}, cuda::stream_ref{custom_stream}};
129128
auto env = stdexec::env{stream_prop, expected_allocation_size(expected_bytes_allocated)};
130129

131-
merge_keys(
132-
keys1.begin(),
133-
static_cast<int>(keys1.size()),
134-
keys2.begin(),
135-
static_cast<int>(keys2.size()),
136-
result.begin(),
137-
cuda::std::less<>{},
138-
env);
130+
merge_keys(keys1.begin(),
131+
static_cast<int>(keys1.size()),
132+
keys2.begin(),
133+
static_cast<int>(keys2.size()),
134+
result.begin(),
135+
cuda::std::less<>{},
136+
env);
139137

140138
REQUIRE(cudaSuccess == cudaStreamSynchronize(custom_stream));
141139

@@ -156,18 +154,19 @@ C2H_TEST("DeviceMerge::MergePairs uses environment", "[merge][device]")
156154
auto result_values = c2h::device_vector<char>(7);
157155

158156
size_t expected_bytes_allocated{};
159-
REQUIRE(cudaSuccess
160-
== cub::DeviceMerge::MergePairs(
161-
nullptr,
162-
expected_bytes_allocated,
163-
keys1.begin(),
164-
values1.begin(),
165-
static_cast<int>(keys1.size()),
166-
keys2.begin(),
167-
values2.begin(),
168-
static_cast<int>(keys2.size()),
169-
result_keys.begin(),
170-
result_values.begin()));
157+
REQUIRE(
158+
cudaSuccess
159+
== cub::DeviceMerge::MergePairs(
160+
nullptr,
161+
expected_bytes_allocated,
162+
keys1.begin(),
163+
values1.begin(),
164+
static_cast<int>(keys1.size()),
165+
keys2.begin(),
166+
values2.begin(),
167+
static_cast<int>(keys2.size()),
168+
result_keys.begin(),
169+
result_values.begin()));
171170

172171
auto env = stdexec::env{expected_allocation_size(expected_bytes_allocated)};
173172

@@ -203,18 +202,19 @@ TEST_CASE("DeviceMerge::MergePairs uses custom stream", "[merge][device]")
203202
REQUIRE(cudaSuccess == cudaStreamCreate(&custom_stream));
204203

205204
size_t expected_bytes_allocated{};
206-
REQUIRE(cudaSuccess
207-
== cub::DeviceMerge::MergePairs(
208-
nullptr,
209-
expected_bytes_allocated,
210-
keys1.begin(),
211-
values1.begin(),
212-
static_cast<int>(keys1.size()),
213-
keys2.begin(),
214-
values2.begin(),
215-
static_cast<int>(keys2.size()),
216-
result_keys.begin(),
217-
result_values.begin()));
205+
REQUIRE(
206+
cudaSuccess
207+
== cub::DeviceMerge::MergePairs(
208+
nullptr,
209+
expected_bytes_allocated,
210+
keys1.begin(),
211+
values1.begin(),
212+
static_cast<int>(keys1.size()),
213+
keys2.begin(),
214+
values2.begin(),
215+
static_cast<int>(keys2.size()),
216+
result_keys.begin(),
217+
result_values.begin()));
218218

219219
auto stream_prop = stdexec::prop{cuda::get_stream_t{}, cuda::stream_ref{custom_stream}};
220220
auto env = stdexec::env{stream_prop, expected_allocation_size(expected_bytes_allocated)};

cub/test/catch2_test_device_merge_env_api.cu

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,8 @@
1717
C2H_TEST("cub::DeviceMerge::MergeKeys accepts env with stream", "[merge][env]")
1818
{
1919
// example-begin merge-keys-env
20-
auto keys1 = thrust::device_vector<int>{0, 2, 5};
21-
auto keys2 = thrust::device_vector<int>{0, 3, 3, 4};
20+
auto keys1 = thrust::device_vector<int>{0, 2, 5};
21+
auto keys2 = thrust::device_vector<int>{0, 3, 3, 4};
2222
auto result = thrust::device_vector<int>(7);
2323

2424
cuda::stream stream{cuda::devices[0]};

0 commit comments

Comments
 (0)