11/*
2- * SPDX-FileCopyrightText: Copyright (c) 2018-2024 , NVIDIA CORPORATION.
2+ * SPDX-FileCopyrightText: Copyright (c) 2018-2026 , NVIDIA CORPORATION.
33 * SPDX-License-Identifier: Apache-2.0
44 */
55
1717#include < raft/random/rng.cuh>
1818#include < raft/util/cudart_utils.hpp>
1919
20- #include < gtest/gtest.h>
21-
2220#include < thrust/device_ptr.h>
2321#include < thrust/execution_policy.h>
2422#include < thrust/transform.h>
2523
24+ #include < gtest/gtest.h>
25+
2626namespace raft {
2727namespace linalg {
2828
@@ -116,9 +116,13 @@ void mapLaunch(OutType* out,
116116 if constexpr (is_kvp<InType>::value) {
117117 map (handle, out_view, KVPAddOp{scalar}, in1_view, in2_view, in3_view);
118118 } else {
119- map (handle, out_view,
120- [=] __device__ (InType a, InType b, InType c) { return a + b + c + scalar; },
121- in1_view, in2_view, in3_view);
119+ map (
120+ handle,
121+ out_view,
122+ [=] __device__ (InType a, InType b, InType c) { return a + b + c + scalar; },
123+ in1_view,
124+ in2_view,
125+ in3_view);
122126 }
123127}
124128
@@ -229,21 +233,21 @@ class MapTest : public ::testing::TestWithParam<MapInputs<InType, IdxType, OutTy
229233 uniform (handle, r, fval2.data (), len, float (-1.0 ), float (1.0 ));
230234 uniform (handle, r, fval3.data (), len, float (-1.0 ), float (1.0 ));
231235
232- raft::device_resources local_handle {stream};
236+ raft::device_resources handle {stream};
233237 auto fkey1_view = raft::make_device_vector_view<const float >(fkey1.data (), fkey1.size ());
234238 auto fkey2_view = raft::make_device_vector_view<const float >(fkey2.data (), fkey2.size ());
235239 auto fkey3_view = raft::make_device_vector_view<const float >(fkey3.data (), fkey3.size ());
236240 auto fval1_view = raft::make_device_vector_view<const float >(fval1.data (), fval1.size ());
237241 auto fval2_view = raft::make_device_vector_view<const float >(fval2.data (), fval2.size ());
238242 auto fval3_view = raft::make_device_vector_view<const float >(fval3.data (), fval3.size ());
239- auto in1_view = raft::make_device_vector_view (in1.data (), in1.size ());
240- auto in2_view = raft::make_device_vector_view (in2.data (), in2.size ());
241- auto in3_view = raft::make_device_vector_view (in3.data (), in3.size ());
243+ auto in1_view = raft::make_device_vector_view (in1.data (), in1.size ());
244+ auto in2_view = raft::make_device_vector_view (in2.data (), in2.size ());
245+ auto in3_view = raft::make_device_vector_view (in3.data (), in3.size ());
242246
243247 auto make_kvp = [] __device__ (float k, float v) { return KVP{static_cast <int >(k), v}; };
244- raft::linalg::map (local_handle , in1_view, make_kvp, fkey1_view, fval1_view);
245- raft::linalg::map (local_handle , in2_view, make_kvp, fkey2_view, fval2_view);
246- raft::linalg::map (local_handle , in3_view, make_kvp, fkey3_view, fval3_view);
248+ raft::linalg::map (handle , in1_view, make_kvp, fkey1_view, fval1_view);
249+ raft::linalg::map (handle , in2_view, make_kvp, fkey2_view, fval2_view);
250+ raft::linalg::map (handle , in3_view, make_kvp, fkey3_view, fval3_view);
247251 } else {
248252 // For padded_float: first create random float arrays, then convert
249253 rmm::device_uvector<float > fin1 (params.len , stream);
@@ -253,7 +257,7 @@ class MapTest : public ::testing::TestWithParam<MapInputs<InType, IdxType, OutTy
253257 uniform (handle, r, fin2.data (), len, float (-1.0 ), float (1.0 ));
254258 uniform (handle, r, fin3.data (), len, float (-1.0 ), float (1.0 ));
255259
256- raft::device_resources local_handle {stream};
260+ raft::device_resources handle {stream};
257261 auto fin1_view = raft::make_device_vector_view (fin1.data (), fin1.size ());
258262 auto fin2_view = raft::make_device_vector_view (fin2.data (), fin2.size ());
259263 auto fin3_view = raft::make_device_vector_view (fin3.data (), fin3.size ());
@@ -262,9 +266,9 @@ class MapTest : public ::testing::TestWithParam<MapInputs<InType, IdxType, OutTy
262266 auto in3_view = raft::make_device_vector_view (in3.data (), in3.size ());
263267
264268 auto add_padding = [] __device__ (float a) { return padded_float (a); };
265- raft::linalg::map (local_handle , in1_view, add_padding, raft::make_const_mdspan (fin1_view));
266- raft::linalg::map (local_handle , in2_view, add_padding, raft::make_const_mdspan (fin2_view));
267- raft::linalg::map (local_handle , in3_view, add_padding, raft::make_const_mdspan (fin3_view));
269+ raft::linalg::map (handle , in1_view, add_padding, raft::make_const_mdspan (fin1_view));
270+ raft::linalg::map (handle , in2_view, add_padding, raft::make_const_mdspan (fin2_view));
271+ raft::linalg::map (handle , in3_view, add_padding, raft::make_const_mdspan (fin3_view));
268272 }
269273
270274 create_ref (out_ref.data (), in1.data (), in2.data (), in3.data (), params.scalar , len, stream);
@@ -292,13 +296,12 @@ class MapOffsetTest : public ::testing::TestWithParam<MapInputs<OutType, IdxType
292296 {
293297 }
294298
295- // Functor for KVP map_offset test (must be public to avoid extended lambda restriction)
299+ // Functor for KVP map_offset test
296300 struct KVPScaleOp {
297301 OutType scalar;
298302 __device__ OutType operator ()(IdxType idx) const
299303 {
300- return OutType{static_cast <int >(idx) * scalar.key ,
301- static_cast <float >(idx) * scalar.value };
304+ return OutType{static_cast <int >(idx) * scalar.key , static_cast <float >(idx) * scalar.value };
302305 }
303306 };
304307
@@ -409,22 +412,22 @@ struct CompareKVP {
409412 {
410413 // Keys must match exactly, values must be within tolerance
411414 if (a.key != b.key ) return false ;
412- float diff = std::abs (a.value - b.value );
413- float m = std::max (std::abs (a.value ), std::abs (b.value ));
415+ float diff = std::abs (a.value - b.value );
416+ float m = std::max (std::abs (a.value ), std::abs (b.value ));
414417 float ratio = diff > eps ? diff / m : diff;
415418 return (ratio <= eps);
416419 }
417420};
418421
419- #define MAP_TEST_KVP (test_type, test_name, inputs ) \
420- typedef RAFT_DEPAREN (test_type) test_name; \
421- TEST_P (test_name, Result) \
422- { \
423- ASSERT_TRUE (devArrMatch (this ->out_ref .data (), \
424- this ->out .data (), \
425- this ->params .len , \
426- CompareKVP (this ->params .tolerance .value ))); \
427- } \
422+ #define MAP_TEST_KVP (test_type, test_name, inputs ) \
423+ typedef RAFT_DEPAREN (test_type) test_name; \
424+ TEST_P (test_name, Result) \
425+ { \
426+ ASSERT_TRUE (devArrMatch (this ->out_ref .data (), \
427+ this ->out .data (), \
428+ this ->params .len , \
429+ CompareKVP (this ->params .tolerance .value ))); \
430+ } \
428431 INSTANTIATE_TEST_SUITE_P (MapTests, test_name, ::testing::ValuesIn(inputs))
429432
430433const std::vector<MapInputs<KVP, int >> inputs_kvp_i32 = {
0 commit comments