Skip to content

Commit 6fd1732

Browse files
committed
vectorized IOs for known KVP types
1 parent 6090da0 commit 6fd1732

File tree

3 files changed

+89
-37
lines changed

3 files changed

+89
-37
lines changed

cpp/include/raft/linalg/detail/map.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -108,7 +108,7 @@ struct ratio_selector {
108108
template <typename T>
109109
constexpr static auto ignoring_alignment() -> ratio_selector
110110
{
111-
// Types that don't support vectorized I/O (e.g., KeyValuePair) must use ratio=1
111+
// Types without IOType specializations must use ratio=1 (non-vectorized access)
112112
if constexpr (!is_vectorizable_type<T>::value) { return ratio_selector{1, 0}; }
113113

114114
constexpr bool T_evenly_fits_in_cache_line = (kCoalescedVectorSize % sizeof(T)) == 0;

cpp/include/raft/util/vectorized.cuh

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2018-2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2018-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

66
#pragma once
77

8+
#include <raft/core/kvp.hpp>
89
#include <raft/util/cuda_utils.cuh>
910

1011
#include <cuda_fp16.h>
@@ -13,7 +14,8 @@
1314

1415
namespace raft {
1516

16-
template <typename math_, int VecLen>
17+
// Third parameter enables SFINAE for conditional specializations (e.g., KeyValuePair by size)
18+
template <typename math_, int VecLen, typename Enable = void>
1719
struct IOType {};
1820
template <>
1921
struct IOType<bool, 1> {
@@ -233,12 +235,59 @@ struct IOType<double, 2> {
233235
typedef double2 Type;
234236
};
235237

238+
/**
239+
* Generic IOType specializations for any KeyValuePair<K, V>, chosen by its sizeof rather than by K and V.
240+
* SFINAE on sizeof enables them only for 4-, 8- and 16-byte pairs; other sizes fall back to the empty primary template.
241+
*
242+
* 4-byte KVP (e.g., <int16_t,int16_t>, <uint8_t,__half>):
243+
* - VecLen=1: int32_t (4 bytes, load 1 KVP)
244+
* - VecLen=2: int2 (8 bytes, load 2 KVPs)
245+
* - VecLen=4: int4 (16 bytes, load 4 KVPs)
246+
*
247+
* 8-byte KVP (e.g., <int,float>, <float,int>, <int,int>, <uint32_t,float>):
248+
* - VecLen=1: int2 (8 bytes, load 1 KVP)
249+
* - VecLen=2: int4 (16 bytes, load 2 KVPs)
250+
*
251+
* 16-byte KVP (e.g., <int64_t,double>, <int,double>, <int64_t,float>):
252+
* - VecLen=1: int4 (16 bytes, load 1 KVP)
253+
*/
254+
255+
template <typename K, typename V>
256+
struct IOType<KeyValuePair<K, V>, 1, std::enable_if_t<sizeof(KeyValuePair<K, V>) == 4>> {
257+
typedef int32_t Type;
258+
};
259+
260+
template <typename K, typename V>
261+
struct IOType<KeyValuePair<K, V>, 2, std::enable_if_t<sizeof(KeyValuePair<K, V>) == 4>> {
262+
typedef int2 Type;
263+
};
264+
265+
template <typename K, typename V>
266+
struct IOType<KeyValuePair<K, V>, 4, std::enable_if_t<sizeof(KeyValuePair<K, V>) == 4>> {
267+
typedef int4 Type;
268+
};
269+
270+
template <typename K, typename V>
271+
struct IOType<KeyValuePair<K, V>, 1, std::enable_if_t<sizeof(KeyValuePair<K, V>) == 8>> {
272+
typedef int2 Type;
273+
};
274+
275+
template <typename K, typename V>
276+
struct IOType<KeyValuePair<K, V>, 2, std::enable_if_t<sizeof(KeyValuePair<K, V>) == 8>> {
277+
typedef int4 Type;
278+
};
279+
280+
template <typename K, typename V>
281+
struct IOType<KeyValuePair<K, V>, 1, std::enable_if_t<sizeof(KeyValuePair<K, V>) == 16>> {
282+
typedef int4 Type;
283+
};
284+
236285
/**
237286
* @brief Type trait to detect if a type supports vectorized I/O operations.
238287
*
239288
* A type is vectorizable if it has an IOType<T, 1> specialization.
240-
* Types like KeyValuePair that don't have IOType specializations
241-
* will return false, causing the map functions to use non-vectorized access.
289+
* KeyValuePair types whose size is 4, 8 or 16 bytes have IOType specializations above.
290+
* Custom types without IOType specializations will use non-vectorized access.
242291
*/
243292
template <typename T, typename = void>
244293
struct is_vectorizable_type : std::false_type {};

cpp/tests/linalg/map.cu

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* SPDX-FileCopyrightText: Copyright (c) 2018-2024, NVIDIA CORPORATION.
2+
* SPDX-FileCopyrightText: Copyright (c) 2018-2026, NVIDIA CORPORATION.
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

@@ -17,12 +17,12 @@
1717
#include <raft/random/rng.cuh>
1818
#include <raft/util/cudart_utils.hpp>
1919

20-
#include <gtest/gtest.h>
21-
2220
#include <thrust/device_ptr.h>
2321
#include <thrust/execution_policy.h>
2422
#include <thrust/transform.h>
2523

24+
#include <gtest/gtest.h>
25+
2626
namespace raft {
2727
namespace linalg {
2828

@@ -116,9 +116,13 @@ void mapLaunch(OutType* out,
116116
if constexpr (is_kvp<InType>::value) {
117117
map(handle, out_view, KVPAddOp{scalar}, in1_view, in2_view, in3_view);
118118
} else {
119-
map(handle, out_view,
120-
[=] __device__(InType a, InType b, InType c) { return a + b + c + scalar; },
121-
in1_view, in2_view, in3_view);
119+
map(
120+
handle,
121+
out_view,
122+
[=] __device__(InType a, InType b, InType c) { return a + b + c + scalar; },
123+
in1_view,
124+
in2_view,
125+
in3_view);
122126
}
123127
}
124128

@@ -229,21 +233,21 @@ class MapTest : public ::testing::TestWithParam<MapInputs<InType, IdxType, OutTy
229233
uniform(handle, r, fval2.data(), len, float(-1.0), float(1.0));
230234
uniform(handle, r, fval3.data(), len, float(-1.0), float(1.0));
231235

232-
raft::device_resources local_handle{stream};
236+
raft::device_resources handle{stream};
233237
auto fkey1_view = raft::make_device_vector_view<const float>(fkey1.data(), fkey1.size());
234238
auto fkey2_view = raft::make_device_vector_view<const float>(fkey2.data(), fkey2.size());
235239
auto fkey3_view = raft::make_device_vector_view<const float>(fkey3.data(), fkey3.size());
236240
auto fval1_view = raft::make_device_vector_view<const float>(fval1.data(), fval1.size());
237241
auto fval2_view = raft::make_device_vector_view<const float>(fval2.data(), fval2.size());
238242
auto fval3_view = raft::make_device_vector_view<const float>(fval3.data(), fval3.size());
239-
auto in1_view = raft::make_device_vector_view(in1.data(), in1.size());
240-
auto in2_view = raft::make_device_vector_view(in2.data(), in2.size());
241-
auto in3_view = raft::make_device_vector_view(in3.data(), in3.size());
243+
auto in1_view = raft::make_device_vector_view(in1.data(), in1.size());
244+
auto in2_view = raft::make_device_vector_view(in2.data(), in2.size());
245+
auto in3_view = raft::make_device_vector_view(in3.data(), in3.size());
242246

243247
auto make_kvp = [] __device__(float k, float v) { return KVP{static_cast<int>(k), v}; };
244-
raft::linalg::map(local_handle, in1_view, make_kvp, fkey1_view, fval1_view);
245-
raft::linalg::map(local_handle, in2_view, make_kvp, fkey2_view, fval2_view);
246-
raft::linalg::map(local_handle, in3_view, make_kvp, fkey3_view, fval3_view);
248+
raft::linalg::map(handle, in1_view, make_kvp, fkey1_view, fval1_view);
249+
raft::linalg::map(handle, in2_view, make_kvp, fkey2_view, fval2_view);
250+
raft::linalg::map(handle, in3_view, make_kvp, fkey3_view, fval3_view);
247251
} else {
248252
// For padded_float: first create random float arrays, then convert
249253
rmm::device_uvector<float> fin1(params.len, stream);
@@ -253,7 +257,7 @@ class MapTest : public ::testing::TestWithParam<MapInputs<InType, IdxType, OutTy
253257
uniform(handle, r, fin2.data(), len, float(-1.0), float(1.0));
254258
uniform(handle, r, fin3.data(), len, float(-1.0), float(1.0));
255259

256-
raft::device_resources local_handle{stream};
260+
raft::device_resources handle{stream};
257261
auto fin1_view = raft::make_device_vector_view(fin1.data(), fin1.size());
258262
auto fin2_view = raft::make_device_vector_view(fin2.data(), fin2.size());
259263
auto fin3_view = raft::make_device_vector_view(fin3.data(), fin3.size());
@@ -262,9 +266,9 @@ class MapTest : public ::testing::TestWithParam<MapInputs<InType, IdxType, OutTy
262266
auto in3_view = raft::make_device_vector_view(in3.data(), in3.size());
263267

264268
auto add_padding = [] __device__(float a) { return padded_float(a); };
265-
raft::linalg::map(local_handle, in1_view, add_padding, raft::make_const_mdspan(fin1_view));
266-
raft::linalg::map(local_handle, in2_view, add_padding, raft::make_const_mdspan(fin2_view));
267-
raft::linalg::map(local_handle, in3_view, add_padding, raft::make_const_mdspan(fin3_view));
269+
raft::linalg::map(handle, in1_view, add_padding, raft::make_const_mdspan(fin1_view));
270+
raft::linalg::map(handle, in2_view, add_padding, raft::make_const_mdspan(fin2_view));
271+
raft::linalg::map(handle, in3_view, add_padding, raft::make_const_mdspan(fin3_view));
268272
}
269273

270274
create_ref(out_ref.data(), in1.data(), in2.data(), in3.data(), params.scalar, len, stream);
@@ -292,13 +296,12 @@ class MapOffsetTest : public ::testing::TestWithParam<MapInputs<OutType, IdxType
292296
{
293297
}
294298

295-
// Functor for KVP map_offset test (must be public to avoid extended lambda restriction)
299+
// Functor for KVP map_offset test
296300
struct KVPScaleOp {
297301
OutType scalar;
298302
__device__ OutType operator()(IdxType idx) const
299303
{
300-
return OutType{static_cast<int>(idx) * scalar.key,
301-
static_cast<float>(idx) * scalar.value};
304+
return OutType{static_cast<int>(idx) * scalar.key, static_cast<float>(idx) * scalar.value};
302305
}
303306
};
304307

@@ -409,22 +412,22 @@ struct CompareKVP {
409412
{
410413
// Keys must match exactly, values must be within tolerance
411414
if (a.key != b.key) return false;
412-
float diff = std::abs(a.value - b.value);
413-
float m = std::max(std::abs(a.value), std::abs(b.value));
415+
float diff = std::abs(a.value - b.value);
416+
float m = std::max(std::abs(a.value), std::abs(b.value));
414417
float ratio = diff > eps ? diff / m : diff;
415418
return (ratio <= eps);
416419
}
417420
};
418421

419-
#define MAP_TEST_KVP(test_type, test_name, inputs) \
420-
typedef RAFT_DEPAREN(test_type) test_name; \
421-
TEST_P(test_name, Result) \
422-
{ \
423-
ASSERT_TRUE(devArrMatch(this->out_ref.data(), \
424-
this->out.data(), \
425-
this->params.len, \
426-
CompareKVP(this->params.tolerance.value))); \
427-
} \
422+
#define MAP_TEST_KVP(test_type, test_name, inputs) \
423+
typedef RAFT_DEPAREN(test_type) test_name; \
424+
TEST_P(test_name, Result) \
425+
{ \
426+
ASSERT_TRUE(devArrMatch(this->out_ref.data(), \
427+
this->out.data(), \
428+
this->params.len, \
429+
CompareKVP(this->params.tolerance.value))); \
430+
} \
428431
INSTANTIATE_TEST_SUITE_P(MapTests, test_name, ::testing::ValuesIn(inputs))
429432

430433
const std::vector<MapInputs<KVP, int>> inputs_kvp_i32 = {

0 commit comments

Comments
 (0)