@@ -321,6 +321,27 @@ auto thrustDeviceAlloc = [](auto alignment, std::size_t size)
321
321
return p;
322
322
};
323
323
324
+ template <typename View>
325
+ struct IndexToViewIterator
326
+ {
327
+ View view;
328
+ LLAMA_FN_HOST_ACC_INLINE auto operator ()(std::size_t i)
329
+ {
330
+ return *(view.begin () + i);
331
+ }
332
+ };
333
+
334
+ template <typename View>
335
+ auto make_view_it (View view, std::size_t i)
336
+ {
337
+ auto ci = thrust::counting_iterator<std::size_t >{0 };
338
+ return thrust::transform_iterator<
339
+ IndexToViewIterator<View>,
340
+ decltype (ci),
341
+ typename View::iterator::reference,
342
+ typename View::iterator::value_type>{ci, IndexToViewIterator<View>{std::move (view)}};
343
+ }
344
+
324
345
template <int Mapping>
325
346
void run (std::ostream& plotFile)
326
347
{
@@ -375,8 +396,16 @@ void run(std::ostream& plotFile)
375
396
376
397
auto view = llama::allocView (mapping, thrustDeviceAlloc);
377
398
399
+ auto b = make_view_it (view, 0 );
400
+ auto e = make_view_it (view, N);
401
+ // auto b = view.begin();
402
+ // auto e = view.end();
403
+
404
+ auto r = (*b);
405
+ r (tag::eventId{}) = 0 ;
406
+
378
407
// touch memory once before running benchmarks
379
- thrust::fill (thrust::device, view. begin (), view. end () , 0 );
408
+ thrust::fill (thrust::device, b, e , 0 );
380
409
syncWithCuda ();
381
410
382
411
// #if THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA
@@ -427,7 +456,7 @@ void run(std::ostream& plotFile)
427
456
}
428
457
else
429
458
{
430
- thrust::tabulate (thrust::device, view. begin (), view. end () , InitOne{});
459
+ thrust::tabulate (thrust::device, b, e , InitOne{});
431
460
syncWithCuda ();
432
461
}
433
462
tabulateTotal += stopwatch.printAndReset (" tabulate" , ' \t ' );
@@ -453,7 +482,7 @@ void run(std::ostream& plotFile)
453
482
{
454
483
Stopwatch stopwatch;
455
484
if constexpr (usePSTL)
456
- std::for_each (exec, view. begin (), view. end () , NormalizeVel{});
485
+ std::for_each (exec, b, e , NormalizeVel{});
457
486
else
458
487
{
459
488
thrust::for_each (
@@ -471,10 +500,10 @@ void run(std::ostream& plotFile)
471
500
thrust::device_vector<MassType> dst (N);
472
501
Stopwatch stopwatch;
473
502
if constexpr (usePSTL)
474
- std::transform (exec, view. begin (), view. end () , dst.begin (), GetMass{});
503
+ std::transform (exec, b, e , dst.begin (), GetMass{});
475
504
else
476
505
{
477
- thrust::transform (thrust::device, view. begin (), view. end () , dst.begin (), GetMass{});
506
+ thrust::transform (thrust::device, b, e , dst.begin (), GetMass{});
478
507
syncWithCuda ();
479
508
}
480
509
transformTotal += stopwatch.printAndReset (" transform" , ' \t ' );
@@ -489,8 +518,8 @@ void run(std::ostream& plotFile)
489
518
if constexpr (usePSTL)
490
519
std::transform_exclusive_scan (
491
520
exec,
492
- view. begin () ,
493
- view. end () ,
521
+ b ,
522
+ e ,
494
523
scan_result.begin (),
495
524
std::uint32_t {0 },
496
525
std::plus<>{},
@@ -499,8 +528,8 @@ void run(std::ostream& plotFile)
499
528
{
500
529
thrust::transform_exclusive_scan (
501
530
thrust::device,
502
- view. begin () ,
503
- view. end () ,
531
+ b ,
532
+ e ,
504
533
scan_result.begin (),
505
534
Predicate{},
506
535
std::uint32_t {0 },
@@ -516,29 +545,24 @@ void run(std::ostream& plotFile)
516
545
{
517
546
Stopwatch stopwatch;
518
547
if constexpr (usePSTL)
519
- sink = std::transform_reduce (exec, view. begin (), view. end () , MassType{0 }, std::plus<>{}, GetMass{});
548
+ sink = std::transform_reduce (exec, b, e , MassType{0 }, std::plus<>{}, GetMass{});
520
549
else
521
550
{
522
- sink = thrust::transform_reduce (
523
- thrust::device,
524
- view.begin (),
525
- view.end (),
526
- GetMass{},
527
- MassType{0 },
528
- thrust::plus<>{});
551
+ sink = thrust::transform_reduce (thrust::device, b, e, GetMass{}, MassType{0 }, thrust::plus<>{});
529
552
syncWithCuda ();
530
553
}
531
554
transformReduceTotal += stopwatch.printAndReset (" transform_reduce" , ' \t ' );
532
555
}
533
556
534
557
{
535
558
auto dstView = llama::allocView (mapping, thrustDeviceAlloc);
559
+ auto db = make_view_it (dstView, 0 );
536
560
Stopwatch stopwatch;
537
561
if constexpr (usePSTL)
538
- std::copy (exec, view. begin (), view. end (), dstView. begin () );
562
+ std::copy (exec, b, e, db );
539
563
else
540
564
{
541
- thrust::copy (thrust::device, view. begin (), view. end (), dstView. begin () );
565
+ thrust::copy (thrust::device, b, e, db );
542
566
syncWithCuda ();
543
567
}
544
568
copyTotal += stopwatch.printAndReset (" copy" , ' \t ' );
@@ -548,12 +572,13 @@ void run(std::ostream& plotFile)
548
572
549
573
{
550
574
auto dstView = llama::allocView (mapping, thrustDeviceAlloc);
575
+ auto db = make_view_it (dstView, 0 );
551
576
Stopwatch stopwatch;
552
577
if constexpr (usePSTL)
553
- std::copy_if (exec, view. begin (), view. end (), dstView. begin () , Predicate{});
578
+ std::copy_if (exec, b, e, db , Predicate{});
554
579
else
555
580
{
556
- thrust::copy_if (thrust::device, view. begin (), view. end (), dstView. begin () , Predicate{});
581
+ thrust::copy_if (thrust::device, b, e, db , Predicate{});
557
582
syncWithCuda ();
558
583
}
559
584
copyIfTotal += stopwatch.printAndReset (" copy_if" , ' \t ' );
@@ -564,10 +589,10 @@ void run(std::ostream& plotFile)
564
589
{
565
590
Stopwatch stopwatch;
566
591
if constexpr (usePSTL)
567
- std::remove_if (exec, view. begin (), view. end () , Predicate{});
592
+ std::remove_if (exec, b, e , Predicate{});
568
593
else
569
594
{
570
- thrust::remove_if (thrust::device, view. begin (), view. end () , Predicate{});
595
+ thrust::remove_if (thrust::device, b, e , Predicate{});
571
596
syncWithCuda ();
572
597
}
573
598
removeIfTotal += stopwatch.printAndReset (" remove_if" , ' \t ' );
@@ -576,14 +601,14 @@ void run(std::ostream& plotFile)
576
601
// {
577
602
// Stopwatch stopwatch;
578
603
// if constexpr(usePSTL)
579
- // std::sort(std::execution::par, view.begin(), view.end() , Less{});
604
+ // std::sort(std::execution::par, b, e , Less{});
580
605
// else
581
606
// {
582
- // thrust::sort(thrust::device, view.begin(), view.end() , Less{});
607
+ // thrust::sort(thrust::device, b, e , Less{});
583
608
// syncWithCuda();
584
609
// }
585
610
// sortTotal += stopwatch.printAndReset("sort", '\t');
586
- // if(!thrust::is_sorted(thrust::device, view.begin(), view.end() , Less{}))
611
+ // if(!thrust::is_sorted(thrust::device, b, e , Less{}))
587
612
// std::cerr << "VALIDATION FAILED\n";
588
613
// }
589
614
0 commit comments