@@ -321,6 +321,29 @@ auto thrustDeviceAlloc = [](auto alignment, std::size_t size)
321
321
return p;
322
322
};
323
323
324
+ template <typename View>
325
+ struct ViewIteratorAt
326
+ {
327
+ View view;
328
+
329
+ LLAMA_FN_HOST_ACC_INLINE auto operator ()(std::size_t i)
330
+ {
331
+ return *(view.begin () + i);
332
+ }
333
+ };
334
+
335
+ template <typename View>
336
+ auto viewIteratorAt (View& view, std::size_t index)
337
+ {
338
+ ViewIteratorAt<View> t{view};
339
+ using ViewTransformIterator = thrust::transform_iterator<
340
+ decltype (t),
341
+ thrust::counting_iterator<std::size_t >,
342
+ typename View::iterator::reference,
343
+ typename View::iterator::value_type>;
344
+ return ViewTransformIterator{thrust::counting_iterator<std::size_t >{index }, t};
345
+ }
346
+
324
347
template <int Mapping>
325
348
void run (std::ostream& plotFile)
326
349
{
@@ -374,15 +397,8 @@ void run(std::ostream& plotFile)
374
397
std::cout << mappingName << ' \n ' ;
375
398
376
399
auto view = llama::allocView (mapping, thrustDeviceAlloc);
377
-
378
- auto makeViewIteratorFromIndexCreator = [](decltype (view) view)
379
- { return [view] __host__ __device__ (std::size_t i) mutable { return *(view.begin () + i); }; };
380
- auto b = thrust::make_transform_iterator (
381
- thrust::counting_iterator<std::size_t >{0 },
382
- makeViewIteratorFromIndexCreator (view));
383
- auto e = thrust::make_transform_iterator (
384
- thrust::counting_iterator<std::size_t >{N},
385
- makeViewIteratorFromIndexCreator (view));
400
+ auto b = viewIteratorAt (view, 0 );
401
+ auto e = viewIteratorAt (view, N);
386
402
// auto b = view.begin();
387
403
// auto e = view.end();
388
404
@@ -541,9 +557,7 @@ void run(std::ostream& plotFile)
541
557
542
558
{
543
559
auto dstView = llama::allocView (mapping, thrustDeviceAlloc);
544
- auto db = thrust::make_transform_iterator (
545
- thrust::counting_iterator<std::size_t >{0 },
546
- makeViewIteratorFromIndexCreator (dstView));
560
+ auto db = viewIteratorAt (dstView, 0 );
547
561
Stopwatch stopwatch;
548
562
if constexpr (usePSTL)
549
563
std::copy (exec, b, e, db);
@@ -559,9 +573,7 @@ void run(std::ostream& plotFile)
559
573
560
574
{
561
575
auto dstView = llama::allocView (mapping, thrustDeviceAlloc);
562
- auto db = thrust::make_transform_iterator (
563
- thrust::counting_iterator<std::size_t >{0 },
564
- makeViewIteratorFromIndexCreator (dstView));
576
+ auto db = viewIteratorAt (dstView, 0 );
565
577
Stopwatch stopwatch;
566
578
if constexpr (usePSTL)
567
579
std::copy_if (exec, b, e, db, Predicate{});
0 commit comments