Skip to content

Commit 7b5b49a

Browse files
create LLAMA iterators on the fly and revert hacks
1 parent 7761702 commit 7b5b49a

File tree

4 files changed

+65
-41
lines changed

4 files changed

+65
-41
lines changed

Diff for: examples/thrust/thrust.cu

+51-26
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,27 @@ auto thrustDeviceAlloc = [](auto alignment, std::size_t size)
321321
return p;
322322
};
323323

324+
template<typename View>
325+
struct IndexToViewIterator
326+
{
327+
View view;
328+
LLAMA_FN_HOST_ACC_INLINE auto operator()(std::size_t i)
329+
{
330+
return *(view.begin() + i);
331+
}
332+
};
333+
334+
template<typename View>
335+
auto make_view_it(View view, std::size_t i)
336+
{
337+
auto ci = thrust::counting_iterator<std::size_t>{0};
338+
return thrust::transform_iterator<
339+
IndexToViewIterator<View>,
340+
decltype(ci),
341+
typename View::iterator::reference,
342+
typename View::iterator::value_type>{ci, IndexToViewIterator<View>{std::move(view)}};
343+
}
344+
324345
template<int Mapping>
325346
void run(std::ostream& plotFile)
326347
{
@@ -375,8 +396,16 @@ void run(std::ostream& plotFile)
375396

376397
auto view = llama::allocView(mapping, thrustDeviceAlloc);
377398

399+
auto b = make_view_it(view, 0);
400+
auto e = make_view_it(view, N);
401+
// auto b = view.begin();
402+
// auto e = view.end();
403+
404+
auto r = (*b);
405+
r(tag::eventId{}) = 0;
406+
378407
// touch memory once before running benchmarks
379-
thrust::fill(thrust::device, view.begin(), view.end(), 0);
408+
thrust::fill(thrust::device, b, e, 0);
380409
syncWithCuda();
381410

382411
//#if THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA
@@ -427,7 +456,7 @@ void run(std::ostream& plotFile)
427456
}
428457
else
429458
{
430-
thrust::tabulate(thrust::device, view.begin(), view.end(), InitOne{});
459+
thrust::tabulate(thrust::device, b, e, InitOne{});
431460
syncWithCuda();
432461
}
433462
tabulateTotal += stopwatch.printAndReset("tabulate", '\t');
@@ -453,7 +482,7 @@ void run(std::ostream& plotFile)
453482
{
454483
Stopwatch stopwatch;
455484
if constexpr(usePSTL)
456-
std::for_each(exec, view.begin(), view.end(), NormalizeVel{});
485+
std::for_each(exec, b, e, NormalizeVel{});
457486
else
458487
{
459488
thrust::for_each(
@@ -471,10 +500,10 @@ void run(std::ostream& plotFile)
471500
thrust::device_vector<MassType> dst(N);
472501
Stopwatch stopwatch;
473502
if constexpr(usePSTL)
474-
std::transform(exec, view.begin(), view.end(), dst.begin(), GetMass{});
503+
std::transform(exec, b, e, dst.begin(), GetMass{});
475504
else
476505
{
477-
thrust::transform(thrust::device, view.begin(), view.end(), dst.begin(), GetMass{});
506+
thrust::transform(thrust::device, b, e, dst.begin(), GetMass{});
478507
syncWithCuda();
479508
}
480509
transformTotal += stopwatch.printAndReset("transform", '\t');
@@ -489,8 +518,8 @@ void run(std::ostream& plotFile)
489518
if constexpr(usePSTL)
490519
std::transform_exclusive_scan(
491520
exec,
492-
view.begin(),
493-
view.end(),
521+
b,
522+
e,
494523
scan_result.begin(),
495524
std::uint32_t{0},
496525
std::plus<>{},
@@ -499,8 +528,8 @@ void run(std::ostream& plotFile)
499528
{
500529
thrust::transform_exclusive_scan(
501530
thrust::device,
502-
view.begin(),
503-
view.end(),
531+
b,
532+
e,
504533
scan_result.begin(),
505534
Predicate{},
506535
std::uint32_t{0},
@@ -516,29 +545,24 @@ void run(std::ostream& plotFile)
516545
{
517546
Stopwatch stopwatch;
518547
if constexpr(usePSTL)
519-
sink = std::transform_reduce(exec, view.begin(), view.end(), MassType{0}, std::plus<>{}, GetMass{});
548+
sink = std::transform_reduce(exec, b, e, MassType{0}, std::plus<>{}, GetMass{});
520549
else
521550
{
522-
sink = thrust::transform_reduce(
523-
thrust::device,
524-
view.begin(),
525-
view.end(),
526-
GetMass{},
527-
MassType{0},
528-
thrust::plus<>{});
551+
sink = thrust::transform_reduce(thrust::device, b, e, GetMass{}, MassType{0}, thrust::plus<>{});
529552
syncWithCuda();
530553
}
531554
transformReduceTotal += stopwatch.printAndReset("transform_reduce", '\t');
532555
}
533556

534557
{
535558
auto dstView = llama::allocView(mapping, thrustDeviceAlloc);
559+
auto db = make_view_it(dstView, 0);
536560
Stopwatch stopwatch;
537561
if constexpr(usePSTL)
538-
std::copy(exec, view.begin(), view.end(), dstView.begin());
562+
std::copy(exec, b, e, db);
539563
else
540564
{
541-
thrust::copy(thrust::device, view.begin(), view.end(), dstView.begin());
565+
thrust::copy(thrust::device, b, e, db);
542566
syncWithCuda();
543567
}
544568
copyTotal += stopwatch.printAndReset("copy", '\t');
@@ -548,12 +572,13 @@ void run(std::ostream& plotFile)
548572

549573
{
550574
auto dstView = llama::allocView(mapping, thrustDeviceAlloc);
575+
auto db = make_view_it(dstView, 0);
551576
Stopwatch stopwatch;
552577
if constexpr(usePSTL)
553-
std::copy_if(exec, view.begin(), view.end(), dstView.begin(), Predicate{});
578+
std::copy_if(exec, b, e, db, Predicate{});
554579
else
555580
{
556-
thrust::copy_if(thrust::device, view.begin(), view.end(), dstView.begin(), Predicate{});
581+
thrust::copy_if(thrust::device, b, e, db, Predicate{});
557582
syncWithCuda();
558583
}
559584
copyIfTotal += stopwatch.printAndReset("copy_if", '\t');
@@ -564,10 +589,10 @@ void run(std::ostream& plotFile)
564589
{
565590
Stopwatch stopwatch;
566591
if constexpr(usePSTL)
567-
std::remove_if(exec, view.begin(), view.end(), Predicate{});
592+
std::remove_if(exec, b, e, Predicate{});
568593
else
569594
{
570-
thrust::remove_if(thrust::device, view.begin(), view.end(), Predicate{});
595+
thrust::remove_if(thrust::device, b, e, Predicate{});
571596
syncWithCuda();
572597
}
573598
removeIfTotal += stopwatch.printAndReset("remove_if", '\t');
@@ -576,14 +601,14 @@ void run(std::ostream& plotFile)
576601
//{
577602
// Stopwatch stopwatch;
578603
// if constexpr(usePSTL)
579-
// std::sort(std::execution::par, view.begin(), view.end(), Less{});
604+
// std::sort(std::execution::par, b, e, Less{});
580605
// else
581606
// {
582-
// thrust::sort(thrust::device, view.begin(), view.end(), Less{});
607+
// thrust::sort(thrust::device, b, e, Less{});
583608
// syncWithCuda();
584609
// }
585610
// sortTotal += stopwatch.printAndReset("sort", '\t');
586-
// if(!thrust::is_sorted(thrust::device, view.begin(), view.end(), Less{}))
611+
// if(!thrust::is_sorted(thrust::device, b, e, Less{}))
587612
// std::cerr << "VALIDATION FAILED\n";
588613
//}
589614

Diff for: include/llama/ArrayIndexRange.hpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,11 @@ namespace llama
120120

121121
current[0] = static_cast<difference_type>(current[0]) + n;
122122
// current is either within bounds or at the end ([last + 1, 0, 0, ..., 0])
123-
//assert(
124-
// (current[0] < extents[0]
125-
// || (current[0] == extents[0]
126-
// && std::all_of(std::begin(current) + 1, std::end(current), [](auto c) { return c == 0; })))
127-
// && "Iterator was moved past the end");
123+
assert(
124+
(current[0] < extents[0]
125+
|| (current[0] == extents[0]
126+
&& std::all_of(std::begin(current) + 1, std::end(current), [](auto c) { return c == 0; })))
127+
&& "Iterator was moved past the end");
128128

129129
return *this;
130130
}

Diff for: include/llama/RecordRef.hpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -381,8 +381,7 @@ namespace llama
381381
using ArrayIndex = typename View::Mapping::ArrayIndex;
382382
using RecordDim = typename View::Mapping::RecordDim;
383383

384-
// std::conditional_t<OwnView, View, View&> view;
385-
View view;
384+
std::conditional_t<OwnView, View, View&> view;
386385

387386
public:
388387
/// Subtree of the record dimension of View starting at BoundRecordCoord. If BoundRecordCoord is

Diff for: include/llama/View.hpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,9 @@ namespace llama
229229

230230
constexpr Iterator() = default;
231231

232-
LLAMA_FN_HOST_ACC_INLINE constexpr Iterator(ArrayIndexIterator arrayIndex, View view)
232+
LLAMA_FN_HOST_ACC_INLINE constexpr Iterator(ArrayIndexIterator arrayIndex, View* view)
233233
: arrayIndex(arrayIndex)
234-
, view(std::move(view))
234+
, view(view)
235235
{
236236
}
237237

@@ -268,7 +268,7 @@ namespace llama
268268
LLAMA_FN_HOST_ACC_INLINE
269269
constexpr auto operator*() const -> reference
270270
{
271-
return const_cast<View&>(view)(*arrayIndex);
271+
return (*view)(*arrayIndex);
272272
}
273273

274274
LLAMA_FN_HOST_ACC_INLINE
@@ -363,7 +363,7 @@ namespace llama
363363
}
364364

365365
ArrayIndexIterator arrayIndex;
366-
View view;
366+
View* view;
367367
};
368368

369369
/// Using a mapping, maps the given array index and record coordinate to a memory reference onto the given blobs.
@@ -559,25 +559,25 @@ namespace llama
559559
LLAMA_FN_HOST_ACC_INLINE
560560
auto begin() -> iterator
561561
{
562-
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.begin(), *this};
562+
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.begin(), this};
563563
}
564564

565565
LLAMA_FN_HOST_ACC_INLINE
566566
auto begin() const -> const_iterator
567567
{
568-
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.begin(), *this};
568+
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.begin(), this};
569569
}
570570

571571
LLAMA_FN_HOST_ACC_INLINE
572572
auto end() -> iterator
573573
{
574-
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.end(), *this};
574+
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.end(), this};
575575
}
576576

577577
LLAMA_FN_HOST_ACC_INLINE
578578
auto end() const -> const_iterator
579579
{
580-
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.end(), *this};
580+
return {ArrayIndexRange<ArrayExtents>{mapping().extents()}.end(), this};
581581
}
582582

583583
Array<BlobType, Mapping::blobCount> storageBlobs;

0 commit comments

Comments
 (0)