Skip to content

Commit 3653ff6

Browse files
fix
1 parent 431b044 commit 3653ff6

File tree

1 file changed

+27
-15
lines changed

1 file changed

+27
-15
lines changed

examples/thrust/thrust.cu

+27-15
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,29 @@ auto thrustDeviceAlloc = [](auto alignment, std::size_t size)
321321
return p;
322322
};
323323

324+
template <typename View>
325+
struct ViewIteratorAt
326+
{
327+
View view;
328+
329+
LLAMA_FN_HOST_ACC_INLINE auto operator()(std::size_t i)
330+
{
331+
return *(view.begin() + i);
332+
}
333+
};
334+
335+
template<typename View>
336+
auto viewIteratorAt(View& view, std::size_t index)
337+
{
338+
ViewIteratorAt<View> t{view};
339+
using ViewTransformIterator = thrust::transform_iterator<
340+
decltype(t),
341+
thrust::counting_iterator<std::size_t>,
342+
typename View::iterator::reference,
343+
typename View::iterator::value_type>;
344+
return ViewTransformIterator{thrust::counting_iterator<std::size_t>{index}, t};
345+
}
346+
324347
template<int Mapping>
325348
void run(std::ostream& plotFile)
326349
{
@@ -374,15 +397,8 @@ void run(std::ostream& plotFile)
374397
std::cout << mappingName << '\n';
375398

376399
auto view = llama::allocView(mapping, thrustDeviceAlloc);
377-
378-
auto makeViewIteratorFromIndexCreator = [](decltype(view) view)
379-
{ return [view] __host__ __device__(std::size_t i) mutable { return *(view.begin() + i); }; };
380-
auto b = thrust::make_transform_iterator(
381-
thrust::counting_iterator<std::size_t>{0},
382-
makeViewIteratorFromIndexCreator(view));
383-
auto e = thrust::make_transform_iterator(
384-
thrust::counting_iterator<std::size_t>{N},
385-
makeViewIteratorFromIndexCreator(view));
400+
auto b = viewIteratorAt(view, 0);
401+
auto e = viewIteratorAt(view, N);
386402
// auto b = view.begin();
387403
// auto e = view.end();
388404

@@ -541,9 +557,7 @@ void run(std::ostream& plotFile)
541557

542558
{
543559
auto dstView = llama::allocView(mapping, thrustDeviceAlloc);
544-
auto db = thrust::make_transform_iterator(
545-
thrust::counting_iterator<std::size_t>{0},
546-
makeViewIteratorFromIndexCreator(dstView));
560+
auto db = viewIteratorAt(dstView, 0);
547561
Stopwatch stopwatch;
548562
if constexpr(usePSTL)
549563
std::copy(exec, b, e, db);
@@ -559,9 +573,7 @@ void run(std::ostream& plotFile)
559573

560574
{
561575
auto dstView = llama::allocView(mapping, thrustDeviceAlloc);
562-
auto db = thrust::make_transform_iterator(
563-
thrust::counting_iterator<std::size_t>{0},
564-
makeViewIteratorFromIndexCreator(dstView));
576+
auto db = viewIteratorAt(dstView, 0);
565577
Stopwatch stopwatch;
566578
if constexpr(usePSTL)
567579
std::copy_if(exec, b, e, db, Predicate{});

0 commit comments

Comments
 (0)