Skip to content

Commit ebaba1a

Browse files
authored
Merge pull request #150 from PointKernel/static-map-retrieve-all
Add `static_map::retrieve_all`
2 parents 8b15f06 + 2793db2 commit ebaba1a

File tree

6 files changed

+178
-6
lines changed

6 files changed

+178
-6
lines changed

include/cuco/detail/static_map.inl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
2+
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -15,6 +15,11 @@
1515
*/
1616

1717
#include <cuco/detail/bitwise_compare.cuh>
18+
#include <cuco/detail/utils.cuh>
19+
20+
#include <thrust/copy.h>
21+
#include <thrust/iterator/transform_iterator.h>
22+
#include <thrust/iterator/zip_iterator.h>
1823

1924
namespace cuco {
2025

@@ -136,6 +141,25 @@ void static_map<Key, Value, Scope, Allocator>::find(InputIt first,
136141
<<<grid_size, block_size, 0, stream>>>(first, last, output_begin, view, hash, key_equal);
137142
}
138143

144+
template <typename Key, typename Value, cuda::thread_scope Scope, typename Allocator>
template <typename KeyOut, typename ValueOut>
std::pair<KeyOut, ValueOut> static_map<Key, Value, Scope, Allocator>::retrieve_all(
  KeyOut keys_out, ValueOut values_out, cudaStream_t stream)
{
  // The slot storage is reinterpreted below as an array of `value_type`; this size check guards
  // that reinterpretation.
  static_assert(sizeof(pair_atomic_type) == sizeof(value_type));

  // Input range: every slot of the map (filled or not), exposed as a (key, value) tuple so it can
  // be assigned through the zipped output iterator.
  auto const slot_tuples_first = thrust::make_transform_iterator(
    reinterpret_cast<value_type*>(slots_), detail::slot_to_tuple<Key, Value>{});
  auto const slot_tuples_last = slot_tuples_first + get_capacity();

  // Output range: keys and values are written in lockstep to the two user-provided iterators.
  auto const output_first = thrust::make_zip_iterator(thrust::make_tuple(keys_out, values_out));

  // Copy only the slots whose key differs from the empty-key sentinel.
  auto const output_last = thrust::copy_if(thrust::cuda::par.on(stream),
                                           slot_tuples_first,
                                           slot_tuples_last,
                                           output_first,
                                           detail::slot_is_filled<Key>{get_empty_key_sentinel()});

  // Both outputs advanced by the same number of elements: the number of slots copied.
  auto const num_retrieved = std::distance(output_first, output_last);
  return std::make_pair(keys_out + num_retrieved, values_out + num_retrieved);
}
162+
139163
template <typename Key, typename Value, cuda::thread_scope Scope, typename Allocator>
140164
template <typename InputIt, typename OutputIt, typename Hash, typename KeyEqual>
141165
void static_map<Key, Value, Scope, Allocator>::contains(InputIt first,

include/cuco/detail/utils.cuh

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2021, NVIDIA CORPORATION.
2+
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -27,5 +27,30 @@ __device__ __forceinline__ int32_t count_least_significant_bits(uint32_t x, int3
2727
return __popc(x & (1 << n) - 1);
2828
}
2929

30+
/**
 * @brief Device functor that repackages a pair-like slot as a `thrust::tuple<Key, Value>`.
 *
 * The tuple is built from the slot's `first` and `second` members, which makes the result
 * assignable through a zip iterator.
 *
 * @tparam Key Type of the first tuple element
 * @tparam Value Type of the second tuple element
 */
template <typename Key, typename Value>
struct slot_to_tuple {
  template <typename Slot>
  __device__ thrust::tuple<Key, Value> operator()(Slot const& slot)
  {
    thrust::tuple<Key, Value> converted(slot.first, slot.second);
    return converted;
  }
};
41+
42+
/**
 * @brief Device predicate telling whether a tuple-like slot holds a key other than the
 * empty-key sentinel.
 *
 * The key is read with `thrust::get<0>` and compared to `empty_key_sentinel` using `!=`.
 *
 * NOTE(review): this relies on `Key` providing a device-callable `operator!=`; confirm whether a
 * bitwise comparison is preferred here for keys without one.
 *
 * @tparam Key Type of the key stored in the slot
 */
template <typename Key>
struct slot_is_filled {
  Key empty_key_sentinel;  ///< Key value that marks a slot as unoccupied

  template <typename Slot>
  __device__ bool operator()(Slot const& slot)
  {
    auto const& slot_key = thrust::get<0>(slot);
    return slot_key != empty_key_sentinel;
  }
};
54+
3055
} // namespace detail
3156
} // namespace cuco

include/cuco/static_map.cuh

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
2+
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -296,6 +296,29 @@ class static_map {
296296
KeyEqual key_equal = KeyEqual{},
297297
cudaStream_t stream = 0);
298298

299+
/**
300+
* @brief Retrieves all of the keys and their associated values.
301+
*
302+
* The order in which keys are returned is implementation defined and not guaranteed to be
303+
* consistent between subsequent calls to `retrieve_all`.
304+
*
305+
* Behavior is undefined if the range beginning at `keys_out` or `values_out` is smaller than
306+
* `get_size()`
307+
*
308+
* @tparam KeyOut Device accessible random access output iterator whose `value_type` is
309+
* convertible from `key_type`.
310+
* @tparam ValueOut Device accessible random access output iterator whose `value_type` is
311+
* convertible from `mapped_type`.
312+
* @param keys_out Beginning output iterator for keys
313+
* @param values_out Beginning output iterator for values
314+
* @param stream CUDA stream used for this operation
315+
* @return Pair of iterators pointing one past the last key and the last value written to the output
316+
*/
317+
template <typename KeyOut, typename ValueOut>
318+
std::pair<KeyOut, ValueOut> retrieve_all(KeyOut keys_out,
319+
ValueOut values_out,
320+
cudaStream_t stream = 0);
321+
299322
/**
300323
* @brief Indicates whether the keys in the range
301324
* `[first, last)` are contained in the map.

tests/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#=============================================================================
2-
# Copyright (c) 2018-2021, NVIDIA CORPORATION.
2+
# Copyright (c) 2018-2022, NVIDIA CORPORATION.
33
#
44
# Licensed under the Apache License, Version 2.0 (the "License");
55
# you may not use this file except in compliance with the License.
@@ -55,6 +55,7 @@ endfunction(ConfigureTest)
5555
# - static_map tests ------------------------------------------------------------------------------
5656
ConfigureTest(STATIC_MAP_TEST
5757
static_map/custom_type_test.cu
58+
static_map/duplicate_keys_test.cu
5859
static_map/key_sentinel_test.cu
5960
static_map/shared_memory_test.cu
6061
static_map/stream_test.cu
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Copyright (c) 2022, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <utils.hpp>
18+
19+
#include <cuco/static_map.cuh>
20+
21+
#include <thrust/device_vector.h>
22+
#include <thrust/functional.h>
23+
#include <thrust/iterator/discard_iterator.h>
24+
#include <thrust/sort.h>
25+
26+
#include <catch2/catch.hpp>
27+
28+
TEMPLATE_TEST_CASE_SIG("Duplicate keys",
29+
"",
30+
((typename Key, typename Value), Key, Value),
31+
(int32_t, int32_t),
32+
(int32_t, int64_t),
33+
(int64_t, int32_t),
34+
(int64_t, int64_t))
35+
{
36+
constexpr std::size_t num_keys{500'000};
37+
cuco::static_map<Key, Value> map{num_keys * 2, -1, -1};
38+
39+
auto m_view = map.get_device_mutable_view();
40+
auto view = map.get_device_view();
41+
42+
thrust::device_vector<Key> d_keys(num_keys);
43+
thrust::device_vector<Value> d_values(num_keys);
44+
45+
thrust::sequence(thrust::device, d_keys.begin(), d_keys.end());
46+
thrust::sequence(thrust::device, d_values.begin(), d_values.end());
47+
48+
auto pairs_begin = thrust::make_transform_iterator(
49+
thrust::make_counting_iterator<int>(0),
50+
[] __device__(auto i) { return cuco::pair_type<Key, Value>(i / 2, i / 2); });
51+
52+
thrust::device_vector<Value> d_results(num_keys);
53+
thrust::device_vector<bool> d_contained(num_keys);
54+
55+
SECTION("Retrieve all entries")
56+
{
57+
auto constexpr gold = num_keys / 2;
58+
thrust::device_vector<Key> unique_keys(gold);
59+
thrust::device_vector<Key> unique_values(gold);
60+
61+
// Retrieve all from an empty map
62+
auto [empty_key_end, empty_value_end] =
63+
map.retrieve_all(unique_keys.begin(), unique_values.begin());
64+
REQUIRE(std::distance(unique_keys.begin(), empty_key_end) == 0);
65+
REQUIRE(std::distance(unique_values.begin(), empty_value_end) == 0);
66+
67+
map.insert(pairs_begin, pairs_begin + num_keys);
68+
69+
auto const num_entries = map.get_size();
70+
REQUIRE(num_entries == gold);
71+
72+
auto [key_out_end, value_out_end] =
73+
map.retrieve_all(unique_keys.begin(), unique_values.begin());
74+
REQUIRE(std::distance(unique_keys.begin(), key_out_end) == gold);
75+
REQUIRE(std::distance(unique_values.begin(), value_out_end) == gold);
76+
77+
thrust::sort(thrust::device, unique_keys.begin(), unique_keys.end());
78+
REQUIRE(cuco::test::equal(unique_keys.begin(),
79+
unique_keys.end(),
80+
thrust::make_counting_iterator<Key>(0),
81+
thrust::equal_to<Key>{}));
82+
}
83+
84+
SECTION("Tests of contains")
85+
{
86+
map.insert(pairs_begin, pairs_begin + num_keys);
87+
map.contains(d_keys.begin(), d_keys.end(), d_contained.begin());
88+
89+
REQUIRE(cuco::test::all_of(d_contained.begin(),
90+
d_contained.begin() + num_keys / 2,
91+
[] __device__(bool const& b) { return b; }));
92+
93+
REQUIRE(cuco::test::none_of(d_contained.begin() + num_keys / 2,
94+
d_contained.end(),
95+
[] __device__(bool const& b) { return b; }));
96+
}
97+
}

tests/utils.hpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
2+
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,9 +16,11 @@
1616

1717
#pragma once
1818

19+
#include <utils.cuh>
20+
1921
#include <thrust/functional.h>
2022

21-
#include <utils.cuh>
23+
#include <cooperative_groups.h>
2224

2325
namespace cuco {
2426
namespace test {

0 commit comments

Comments
 (0)