Skip to content

Commit 54ff97a

Browse files
create LLAMA iterators on the fly and revert hacks
1 parent 11a9ecf commit 54ff97a

File tree

4 files changed

+65
-41
lines changed

4 files changed

+65
-41
lines changed

examples/thrust/thrust.cu

+51-26
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,27 @@ auto thrustDeviceAlloc = [](auto alignment, std::size_t size)
321321
return p;
322322
};
323323

324+
template<typename View>
325+
struct IndexToViewIterator
326+
{
327+
View view;
328+
LLAMA_FN_HOST_ACC_INLINE auto operator()(std::size_t i)
329+
{
330+
return *(view.begin() + i);
331+
}
332+
};
333+
334+
template<typename View>
335+
auto make_view_it(View view, std::size_t i)
336+
{
337+
auto ci = thrust::counting_iterator<std::size_t>{0};
338+
return thrust::transform_iterator<
339+
IndexToViewIterator<View>,
340+
decltype(ci),
341+
typename View::iterator::reference,
342+
typename View::iterator::value_type>{ci, IndexToViewIterator<View>{std::move(view)}};
343+
}
344+
324345
template<int Mapping>
325346
void run(std::ostream& plotFile)
326347
{
@@ -375,8 +396,16 @@ void run(std::ostream& plotFile)
375396

376397
auto view = llama::allocView(mapping, thrustDeviceAlloc);
377398

399+
auto b = make_view_it(view, 0);
400+
auto e = make_view_it(view, N);
401+
// auto b = view.begin();
402+
// auto e = view.end();
403+
404+
auto r = (*b);
405+
r(tag::eventId{}) = 0;
406+
378407
// touch memory once before running benchmarks
379-
thrust::fill(thrust::device, view.begin(), view.end(), 0);
408+
thrust::fill(thrust::device, b, e, 0);
380409
syncWithCuda();
381410

382411
//#if THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA
@@ -427,7 +456,7 @@ void run(std::ostream& plotFile)
427456
}
428457
else
429458
{
430-
thrust::tabulate(thrust::device, view.begin(), view.end(), InitOne{});
459+
thrust::tabulate(thrust::device, b, e, InitOne{});
431460
syncWithCuda();
432461
}
433462
tabulateTotal += stopwatch.printAndReset("tabulate", '\t');
@@ -453,7 +482,7 @@ void run(std::ostream& plotFile)
453482
{
454483
Stopwatch stopwatch;
455484
if constexpr(usePSTL)
456-
std::for_each(exec, view.begin(), view.end(), NormalizeVel{});
485+
std::for_each(exec, b, e, NormalizeVel{});
457486
else
458487
{
459488
thrust::for_each(
@@ -471,10 +500,10 @@ void run(std::ostream& plotFile)
471500
thrust::device_vector<MassType> dst(N);
472501
Stopwatch stopwatch;
473502
if constexpr(usePSTL)
474-
std::transform(exec, view.begin(), view.end(), dst.begin(), GetMass{});
503+
std::transform(exec, b, e, dst.begin(), GetMass{});
475504
else
476505
{
477-
thrust::transform(thrust::device, view.begin(), view.end(), dst.begin(), GetMass{});
506+
thrust::transform(thrust::device, b, e, dst.begin(), GetMass{});
478507
syncWithCuda();
479508
}
480509
transformTotal += stopwatch.printAndReset("transform", '\t');
@@ -489,8 +518,8 @@ void run(std::ostream& plotFile)
489518
if constexpr(usePSTL)
490519
std::transform_exclusive_scan(
491520
exec,
492-
view.begin(),
493-
view.end(),
521+
b,
522+
e,
494523
scan_result.begin(),
495524
std::uint32_t{0},
496525
std::plus<>{},
@@ -499,8 +528,8 @@ void run(std::ostream& plotFile)
499528
{
500529
thrust::transform_exclusive_scan(
501530
thrust::device,
502-
view.begin(),
503-
view.end(),
531+
b,
532+
e,
504533
scan_result.begin(),
505534
Predicate{},
506535
std::uint32_t{0},
@@ -516,29 +545,24 @@ void run(std::ostream& plotFile)
516545
{
517546
Stopwatch stopwatch;
518547
if constexpr(usePSTL)
519-
sink = std::transform_reduce(exec, view.begin(), view.end(), MassType{0}, std::plus<>{}, GetMass{});
548+
sink = std::transform_reduce(exec, b, e, MassType{0}, std::plus<>{}, GetMass{});
520549
else
521550
{
522-
sink = thrust::transform_reduce(
523-
thrust::device,
524-
view.begin(),
525-
view.end(),
526-
GetMass{},
527-
MassType{0},
528-
thrust::plus<>{});
551+
sink = thrust::transform_reduce(thrust::device, b, e, GetMass{}, MassType{0}, thrust::plus<>{});
529552
syncWithCuda();
530553
}
531554
transformReduceTotal += stopwatch.printAndReset("transform_reduce", '\t');
532555
}
533556

534557
{
535558
auto dstView = llama::allocView(mapping, thrustDeviceAlloc);
559+
auto db = make_view_it(dstView, 0);
536560
Stopwatch stopwatch;
537561
if constexpr(usePSTL)
538-
std::copy(exec, view.begin(), view.end(), dstView.begin());
562+
std::copy(exec, b, e, db);
539563
else
540564
{
541-
thrust::copy(thrust::device, view.begin(), view.end(), dstView.begin());
565+
thrust::copy(thrust::device, b, e, db);
542566
syncWithCuda();
543567
}
544568
copyTotal += stopwatch.printAndReset("copy", '\t');
@@ -548,12 +572,13 @@ void run(std::ostream& plotFile)
548572

549573
{
550574
auto dstView = llama::allocView(mapping, thrustDeviceAlloc);
575+
auto db = make_view_it(dstView, 0);
551576
Stopwatch stopwatch;
552577
if constexpr(usePSTL)
553-
std::copy_if(exec, view.begin(), view.end(), dstView.begin(), Predicate{});
578+
std::copy_if(exec, b, e, db, Predicate{});
554579
else
555580
{
556-
thrust::copy_if(thrust::device, view.begin(), view.end(), dstView.begin(), Predicate{});
581+
thrust::copy_if(thrust::device, b, e, db, Predicate{});
557582
syncWithCuda();
558583
}
559584
copyIfTotal += stopwatch.printAndReset("copy_if", '\t');
@@ -564,10 +589,10 @@ void run(std::ostream& plotFile)
564589
{
565590
Stopwatch stopwatch;
566591
if constexpr(usePSTL)
567-
std::remove_if(exec, view.begin(), view.end(), Predicate{});
592+
std::remove_if(exec, b, e, Predicate{});
568593
else
569594
{
570-
thrust::remove_if(thrust::device, view.begin(), view.end(), Predicate{});
595+
thrust::remove_if(thrust::device, b, e, Predicate{});
571596
syncWithCuda();
572597
}
573598
removeIfTotal += stopwatch.printAndReset("remove_if", '\t');
@@ -576,14 +601,14 @@ void run(std::ostream& plotFile)
576601
//{
577602
// Stopwatch stopwatch;
578603
// if constexpr(usePSTL)
579-
// std::sort(std::execution::par, view.begin(), view.end(), Less{});
604+
// std::sort(std::execution::par, b, e, Less{});
580605
// else
581606
// {
582-
// thrust::sort(thrust::device, view.begin(), view.end(), Less{});
607+
// thrust::sort(thrust::device, b, e, Less{});
583608
// syncWithCuda();
584609
// }
585610
// sortTotal += stopwatch.printAndReset("sort", '\t');
586-
// if(!thrust::is_sorted(thrust::device, view.begin(), view.end(), Less{}))
611+
// if(!thrust::is_sorted(thrust::device, b, e, Less{}))
587612
// std::cerr << "VALIDATION FAILED\n";
588613
//}
589614

include/llama/ArrayIndexRange.hpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,11 @@ namespace llama
119119

120120
current[0] = static_cast<difference_type>(current[0]) + n;
121121
// current is either within bounds or at the end ([last + 1, 0, 0, ..., 0])
122-
//assert(
123-
// (current[0] < extents[0]
124-
// || (current[0] == extents[0]
125-
// && std::all_of(std::begin(current) + 1, std::end(current), [](auto c) { return c == 0; })))
126-
// && "Iterator was moved past the end");
122+
assert(
123+
(current[0] < extents[0]
124+
|| (current[0] == extents[0]
125+
&& std::all_of(std::begin(current) + 1, std::end(current), [](auto c) { return c == 0; })))
126+
&& "Iterator was moved past the end");
127127

128128
return *this;
129129
}

include/llama/View.hpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,9 @@ namespace llama
149149

150150
constexpr Iterator() = default;
151151

152-
LLAMA_FN_HOST_ACC_INLINE constexpr Iterator(ArrayIndexIterator arrayIndex, View view)
152+
LLAMA_FN_HOST_ACC_INLINE constexpr Iterator(ArrayIndexIterator arrayIndex, View* view)
153153
: arrayIndex(arrayIndex)
154-
, view(std::move(view))
154+
, view(view)
155155
{
156156
}
157157

@@ -188,7 +188,7 @@ namespace llama
188188
LLAMA_FN_HOST_ACC_INLINE
189189
constexpr auto operator*() const -> reference
190190
{
191-
return const_cast<View&>(view)(*arrayIndex);
191+
return (*view)(*arrayIndex);
192192
}
193193

194194
LLAMA_FN_HOST_ACC_INLINE
@@ -283,7 +283,7 @@ namespace llama
283283
}
284284

285285
ArrayIndexIterator arrayIndex;
286-
View view;
286+
View* view;
287287
};
288288

289289
/// Using a mapping, maps the given array index and record coordinate to a memory reference onto the given blobs.
@@ -462,25 +462,25 @@ namespace llama
462462
LLAMA_FN_HOST_ACC_INLINE
463463
auto begin() -> iterator
464464
{
465-
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.begin(), *this};
465+
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.begin(), this};
466466
}
467467

468468
LLAMA_FN_HOST_ACC_INLINE
469469
auto begin() const -> const_iterator
470470
{
471-
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.begin(), *this};
471+
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.begin(), this};
472472
}
473473

474474
LLAMA_FN_HOST_ACC_INLINE
475475
auto end() -> iterator
476476
{
477-
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.end(), *this};
477+
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.end(), this};
478478
}
479479

480480
LLAMA_FN_HOST_ACC_INLINE
481481
auto end() const -> const_iterator
482482
{
483-
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.end(), *this};
483+
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.end(), this};
484484
}
485485

486486
Array<BlobType, Mapping::blobCount> storageBlobs;

include/llama/VirtualRecord.hpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -349,8 +349,7 @@ namespace llama
349349
using ArrayIndex = typename View::Mapping::ArrayIndex;
350350
using RecordDim = typename View::Mapping::RecordDim;
351351

352-
// std::conditional_t<OwnView, View, View&> view;
353-
View view;
352+
std::conditional_t<OwnView, View, View&> view;
354353

355354
public:
356355
/// Subtree of the record dimension of View starting at BoundRecordCoord. If BoundRecordCoord is

0 commit comments

Comments
 (0)