Skip to content

Commit 51d146e

Browse files
committed
add dpcpp kernel
1 parent 47ccf63 commit 51d146e

1 file changed

Lines changed: 23 additions & 1 deletion

File tree

dpcpp/matrix/dense_kernels.dp.cpp

Lines changed: 23 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
#include "dpcpp/base/math.hpp"
2626
#include "dpcpp/base/onemkl_bindings.hpp"
2727
#include "dpcpp/base/types.hpp"
28+
#include "dpcpp/components/atomic.dp.hpp"
2829
#include "dpcpp/components/cooperative_groups.dp.hpp"
2930
#include "dpcpp/components/reduction.dp.hpp"
3031
#include "dpcpp/components/thread_ids.dp.hpp"
@@ -589,7 +590,28 @@ void scatter_add(std::shared_ptr<const DpcppExecutor> exec,
589590
matrix::view::dense<const ValueType> source,
590591
matrix::view::dense<ValueType> target)
591592
{
592-
GKO_NOT_IMPLEMENTED;
593+
auto nrows = source.size[0];
594+
auto ncols = source.size[1];
595+
if (nrows == 0 || ncols == 0) {
596+
return;
597+
}
598+
auto total = nrows * ncols;
599+
auto queue = exec->get_queue();
600+
// Use const copies for capture
601+
auto src_vals = source.values;
602+
auto src_stride = source.stride;
603+
auto tgt_vals = target.values;
604+
auto tgt_stride = target.stride;
605+
queue->submit([&](sycl::handler& cgh) {
606+
cgh.parallel_for(sycl::range<1>(total), [=](sycl::id<1> idx_id) {
607+
auto idx = idx_id[0];
608+
auto row = idx / ncols;
609+
auto col = idx % ncols;
610+
auto target_row = static_cast<size_type>(scatter_indices[row]);
611+
atomic_add(tgt_vals + target_row * tgt_stride + col,
612+
src_vals[row * src_stride + col]);
613+
});
614+
});
593615
}
594616

595617
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(

0 commit comments

Comments
 (0)