|
25 | 25 | #include "dpcpp/base/math.hpp" |
26 | 26 | #include "dpcpp/base/onemkl_bindings.hpp" |
27 | 27 | #include "dpcpp/base/types.hpp" |
| 28 | +#include "dpcpp/components/atomic.dp.hpp" |
28 | 29 | #include "dpcpp/components/cooperative_groups.dp.hpp" |
29 | 30 | #include "dpcpp/components/reduction.dp.hpp" |
30 | 31 | #include "dpcpp/components/thread_ids.dp.hpp" |
@@ -589,7 +590,28 @@ void scatter_add(std::shared_ptr<const DpcppExecutor> exec, |
589 | 590 | matrix::view::dense<const ValueType> source, |
590 | 591 | matrix::view::dense<ValueType> target) |
591 | 592 | { |
592 | | - GKO_NOT_IMPLEMENTED; |
| 593 | + auto nrows = source.size[0]; |
| 594 | + auto ncols = source.size[1]; |
| 595 | + if (nrows == 0 || ncols == 0) { |
| 596 | + return; |
| 597 | + } |
| 598 | + auto total = nrows * ncols; |
| 599 | + auto queue = exec->get_queue(); |
| 600 | + // Copy the view fields into plain locals so the kernel lambda captures them by value |
| 601 | + auto src_vals = source.values; |
| 602 | + auto src_stride = source.stride; |
| 603 | + auto tgt_vals = target.values; |
| 604 | + auto tgt_stride = target.stride; |
| 605 | + queue->submit([&](sycl::handler& cgh) { |
| 606 | + cgh.parallel_for(sycl::range<1>(total), [=](sycl::id<1> idx_id) { |
| 607 | + auto idx = idx_id[0]; |
| 608 | + auto row = idx / ncols; |
| 609 | + auto col = idx % ncols; |
| 610 | + auto target_row = static_cast<size_type>(scatter_indices[row]); |
| 611 | + atomic_add(tgt_vals + target_row * tgt_stride + col, |
| 612 | + src_vals[row * src_stride + col]); |
| 613 | + }); |
| 614 | + }); |
593 | 615 | } |
594 | 616 |
|
595 | 617 | GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE( |
|
0 commit comments