Commit f090f0a

Update alpaka nbody to C++20 and simplify
1 parent bd95c53 commit f090f0a

File tree

2 files changed: +26 -34

examples/alpaka/nbody/CMakeLists.txt (+1 -1)
@@ -15,7 +15,7 @@ if (NOT TARGET llama::llama)
 endif()
 find_package(alpaka 1.0 REQUIRED)
 alpaka_add_executable(${PROJECT_NAME} nbody.cpp ../../common/Stopwatch.hpp)
-target_compile_features(${PROJECT_NAME} PRIVATE cxx_std_17)
+target_compile_features(${PROJECT_NAME} PRIVATE cxx_std_20)
 target_link_libraries(${PROJECT_NAME} PRIVATE llama::llama fmt::fmt alpaka::alpaka xsimd)

 if (MSVC)
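
The only change here is the bump from cxx_std_17 to cxx_std_20: the reworked kernels in nbody.cpp below use C++20 abbreviated function templates (plain auto parameters on operator()) and a lambda with an explicit template parameter list, neither of which compiles under C++17. A minimal standalone sketch of the abbreviated-function-template part, with hypothetical names not taken from the repository (the lambda idiom is sketched after the nbody.cpp diff):

#include <cstdio>

// C++17 style: a generic call operator needs an explicit template head,
// as the old UpdateKernel and MoveKernel had.
struct KernelCxx17
{
    template<typename Acc, typename View>
    void operator()(const Acc& acc, View view) const
    {
        std::printf("%d %d\n", acc, view);
    }
};

// C++20 style: auto parameters declare the same templates implicitly
// (abbreviated function template), as the updated kernels do.
struct KernelCxx20
{
    void operator()(const auto& acc, auto view) const
    {
        std::printf("%d %d\n", acc, view);
    }
};

int main()
{
    KernelCxx17{}(1, 2); // Acc = int, View = int
    KernelCxx20{}(1, 2); // deduced the same way
}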

examples/alpaka/nbody/nbody.cpp (+25 -33)
@@ -43,9 +43,9 @@ constexpr auto runUpdate = true; // run update step. Useful to disable for bench
 #endif

 #if ANY_CPU_ENABLED
-constexpr auto elementsPerThread = xsimd::batch<float>::size;
 constexpr auto threadsPerBlock = 1;
 constexpr auto sharedElementsPerBlock = 1;
+constexpr auto elementsPerThread = xsimd::batch<float>::size;
 constexpr auto aosoaLanes = elementsPerThread;
 #elif ANY_GPU_ENABLED
 constexpr auto threadsPerBlock = 256;
@@ -101,9 +101,6 @@ struct llama::SimdTraits<Batch, std::enable_if_t<xsimd::is_batch<Batch>::value>>
     }
 };

-template<typename T>
-using MakeBatch = xsimd::batch<T>;
-
 template<typename T, std::size_t N>
 struct MakeSizedBatchImpl
 {
@@ -166,66 +163,60 @@ LLAMA_FN_HOST_ACC_INLINE void pPInteraction(const Acc& acc, ParticleRefI& pis, P
     pis(tag::Vel{}) += dist * sts;
 }

-template<int Elems, typename QuotedSMMapping>
+template<int ThreadsPerBlock, int SharedElementsPerBlock, int ElementsPerThread, typename QuotedSMMapping>
 struct UpdateKernel
 {
-    template<typename Acc, typename View>
-    ALPAKA_FN_HOST_ACC void operator()(const Acc& acc, View particles) const
+    ALPAKA_FN_HOST_ACC void operator()(const auto& acc, auto particles) const
     {
         auto sharedView = [&]
         {
             // if there is only 1 shared element per block, use just a variable (in registers) instead of shared memory
-            if constexpr(sharedElementsPerBlock == 1)
+            if constexpr(SharedElementsPerBlock == 1)
             {
                 using Mapping = llama::mapping::MinAlignedOne<llama::ArrayExtents<int, 1>, SharedMemoryParticle>;
                 return allocViewUninitialized(Mapping{}, llama::bloballoc::Array<Mapping{}.blobSize(0)>{});
             }
             else
             {
-                using ArrayExtents = llama::ArrayExtents<int, sharedElementsPerBlock>;
-                using Mapping = typename QuotedSMMapping::template fn<ArrayExtents, SharedMemoryParticle>;
-                constexpr auto sharedMapping = Mapping{};
-
-                llama::Array<std::byte*, Mapping::blobCount> sharedMems{};
-                boost::mp11::mp_for_each<boost::mp11::mp_iota_c<Mapping::blobCount>>(
-                    [&](auto i)
-                    {
-                        auto& sharedMem = alpaka::declareSharedVar<std::byte[sharedMapping.blobSize(i)], i>(acc);
-                        sharedMems[i] = &sharedMem[0];
-                    });
-                return llama::View{sharedMapping, sharedMems};
+                using Mapping = typename QuotedSMMapping::
+                    template fn<llama::ArrayExtents<int, SharedElementsPerBlock>, SharedMemoryParticle>;
+                return [&]<std::size_t... Is>(std::index_sequence<Is...>)
+                {
+                    return llama::View{
+                        Mapping{},
+                        llama::Array{alpaka::declareSharedVar<std::byte[Mapping{}.blobSize(Is)], Is>(acc)...}};
+                }(std::make_index_sequence<Mapping::blobCount>{});
             }
         }();

         const auto ti = alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc)[0];
         const auto tbi = alpaka::getIdx<alpaka::Block, alpaka::Threads>(acc)[0];

-        auto pis = llama::SimdN<typename View::RecordDim, Elems, MakeSizedBatch>{};
-        llama::loadSimd(particles(ti * Elems), pis);
+        auto pis = llama::SimdN<Particle, ElementsPerThread, MakeSizedBatch>{};
+        llama::loadSimd(particles(ti * ElementsPerThread), pis);

-        for(int blockOffset = 0; blockOffset < problemSize; blockOffset += sharedElementsPerBlock)
+        for(int blockOffset = 0; blockOffset < problemSize; blockOffset += SharedElementsPerBlock)
         {
-            for(int j = 0; j < sharedElementsPerBlock; j += threadsPerBlock)
+            for(int j = 0; j < SharedElementsPerBlock; j += ThreadsPerBlock)
                 sharedView(j) = particles(blockOffset + tbi + j);
             alpaka::syncBlockThreads(acc);
-            for(int j = 0; j < sharedElementsPerBlock; ++j)
+            for(int j = 0; j < SharedElementsPerBlock; ++j)
                 pPInteraction(acc, pis, sharedView(j));
             alpaka::syncBlockThreads(acc);
         }
-        llama::storeSimd(pis(tag::Vel{}), particles(ti * Elems)(tag::Vel{}));
+        llama::storeSimd(pis(tag::Vel{}), particles(ti * ElementsPerThread)(tag::Vel{}));
     }
 };

-template<int Elems>
+template<int ElementsPerThread>
 struct MoveKernel
 {
-    template<typename Acc, typename View>
-    ALPAKA_FN_HOST_ACC void operator()(const Acc& acc, View particles) const
+    ALPAKA_FN_HOST_ACC void operator()(const auto& acc, auto particles) const
    {
         const auto ti = alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc)[0];
-        const auto i = ti * Elems;
-        llama::SimdN<Vec3, Elems, MakeSizedBatch> pos;
-        llama::SimdN<Vec3, Elems, MakeSizedBatch> vel;
+        const auto i = ti * ElementsPerThread;
+        llama::SimdN<Vec3, ElementsPerThread, MakeSizedBatch> pos;
+        llama::SimdN<Vec3, ElementsPerThread, MakeSizedBatch> vel;
         llama::loadSimd(particles(i)(tag::Pos{}), pos);
         llama::loadSimd(particles(i)(tag::Vel{}), vel);
         llama::storeSimd(pos + vel * +timestep, particles(i)(tag::Pos{}));
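
The boost::mp11::mp_for_each loop that filled sharedMems with one shared-memory blob pointer per mapping blob is replaced by an immediately invoked C++20 templated lambda: the index_sequence pack Is... expands one alpaka::declareSharedVar call per blob directly into the llama::Array constructor, so the mutable pointer array and the loop disappear. The same idiom in isolation, a minimal sketch where declareBuffer and makeBuffers are hypothetical stand-ins for the alpaka and LLAMA calls:

#include <array>
#include <cstddef>
#include <cstdio>
#include <utility>

// Hypothetical stand-in for alpaka::declareSharedVar: hands out a distinct
// static buffer for each compile-time index I.
template<std::size_t Size, std::size_t I>
std::byte* declareBuffer()
{
    static std::byte storage[Size];
    return storage;
}

// Collect N buffer pointers by expanding a parameter pack over
// std::index_sequence inside an immediately invoked templated lambda,
// mirroring the new shared-memory view construction in UpdateKernel.
template<std::size_t N>
std::array<std::byte*, N> makeBuffers()
{
    return []<std::size_t... Is>(std::index_sequence<Is...>)
    {
        return std::array<std::byte*, N>{declareBuffer<64, Is>()...};
    }(std::make_index_sequence<N>{});
}

int main()
{
    const auto buffers = makeBuffers<3>();
    for(std::byte* p : buffers)
        std::printf("%p\n", static_cast<void*>(p));
}
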
@@ -354,7 +345,8 @@ void run(std::ostream& plotFile)
     {
         if constexpr(runUpdate)
         {
-            auto updateKernel = UpdateKernel<elementsPerThread, QuotedMappingSM>{};
+            auto updateKernel
+                = UpdateKernel<threadsPerBlock, sharedElementsPerBlock, elementsPerThread, QuotedMappingSM>{};
             alpaka::exec<Acc>(queue, workdiv, updateKernel, llama::shallowCopy(accView));
             statsUpdate(watch.printAndReset("update", '\t'));
         }
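
UpdateKernel is now instantiated with threadsPerBlock, sharedElementsPerBlock, and elementsPerThread as template arguments rather than reading the file-scope constants from inside the kernel, so each instantiation carries its complete configuration. A generic, alpaka-free sketch of that pattern, assuming only the standard library (ScaleKernel and its parameters are hypothetical):

#include <cstdio>
#include <vector>

// A kernel-like functor that takes its work division as non-type template
// parameters instead of reading file-scope constants (hypothetical example).
template<int ThreadsPerBlock, int ElementsPerThread>
struct ScaleKernel
{
    // C++20 abbreviated function template, like the updated kernels.
    void operator()(const auto& factor, auto& data) const
    {
        for(auto& x : data)
            x *= factor;
        std::printf("configured as %d threads/block, %d elements/thread\n", ThreadsPerBlock, ElementsPerThread);
    }
};

int main()
{
    std::vector<float> data(16, 1.0f);
    // The configuration is fixed where the kernel is instantiated, analogous to
    // UpdateKernel<threadsPerBlock, sharedElementsPerBlock, elementsPerThread, QuotedMappingSM>{}.
    auto kernel = ScaleKernel<256, 4>{};
    kernel(2.0f, data);
}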
