@@ -30,13 +30,14 @@ template <typename LocalIndexType>
3030mpi::request RowGatherer<LocalIndexType>::apply_async(ptr_param<const LinOp> b,
3131 ptr_param<LinOp> x) const
3232{
33- return apply_async (b, x, send_workspace_ );
33+ return apply_async (b, x, send_cache_ );
3434}
3535
3636
3737template <typename LocalIndexType>
3838mpi::request RowGatherer<LocalIndexType>::apply_async(
39- ptr_param<const LinOp> b, ptr_param<LinOp> x, array<char >& workspace) const
39+ ptr_param<const LinOp> b, ptr_param<LinOp> x,
40+ gko::detail::GenericDenseCache& workspace) const
4041{
4142 auto ev = this ->apply_prepare (b, workspace);
4243 return this ->apply_finalize (b, x, ev, workspace);
@@ -46,13 +47,13 @@ template <typename LocalIndexType>
4647std::shared_ptr<const gko::detail::Event>
4748RowGatherer<LocalIndexType>::apply_prepare(ptr_param<const LinOp> b) const
4849{
49- return apply_prepare (b, send_workspace_ );
50+ return apply_prepare (b, send_cache_ );
5051}
5152
5253template <typename LocalIndexType>
5354std::shared_ptr<const gko::detail::Event>
54- RowGatherer<LocalIndexType>::apply_prepare(ptr_param< const LinOp> b,
55- array< char > & workspace) const
55+ RowGatherer<LocalIndexType>::apply_prepare(
56+ ptr_param< const LinOp> b, gko::detail::GenericDenseCache & workspace) const
5657{
5758 std::shared_ptr<const gko::detail::Event> ev = nullptr ;
5859 auto exec = this ->get_executor ();
@@ -79,24 +80,8 @@ RowGatherer<LocalIndexType>::apply_prepare(ptr_param<const LinOp> b,
7980
8081 dim<2 > send_size (coll_comm_->get_send_size (),
8182 b_local->get_size ()[1 ]);
82- auto send_size_in_bytes =
83- sizeof (ValueType) * send_size[0 ] * send_size[1 ];
84- // TODO: can not combine them to assignment because array
85- // assignment will copy the data to the place without
86- // changing executor.
87- if (!workspace.get_executor () ||
88- !mpi_exec->memory_accessible (workspace.get_executor ())) {
89- workspace.set_executor (mpi_exec);
90- }
91- if (send_size_in_bytes > workspace.get_size ()) {
92- workspace.resize_and_reset (send_size_in_bytes);
93- }
94- auto send_buffer = matrix::Dense<ValueType>::create (
95- mpi_exec, send_size,
96- make_array_view (
97- mpi_exec, send_size[0 ] * send_size[1 ],
98- reinterpret_cast <ValueType*>(workspace.get_data ())),
99- send_size[1 ]);
83+ auto send_buffer =
84+ workspace.get <ValueType>(mpi_exec, send_size);
10085 b_local->row_gather (&send_idxs_, send_buffer);
10186 b_local->get_executor ()->run (event::make_record_event (ev));
10287 });
@@ -110,14 +95,15 @@ mpi::request RowGatherer<LocalIndexType>::apply_finalize(
11095 ptr_param<const LinOp> b, ptr_param<LinOp> x,
11196 std::shared_ptr<const gko::detail::Event> ev) const
11297{
113- auto req = apply_finalize (b, x, ev, send_workspace_ );
98+ auto req = apply_finalize (b, x, ev, send_cache_ );
11499 return req;
115100}
116101
117102template <typename LocalIndexType>
118103mpi::request RowGatherer<LocalIndexType>::apply_finalize(
119104 ptr_param<const LinOp> b, ptr_param<LinOp> x,
120- std::shared_ptr<const gko::detail::Event> ev, array<char >& workspace) const
105+ std::shared_ptr<const gko::detail::Event> ev,
106+ gko::detail::GenericDenseCache& workspace) const
121107{
122108 mpi::request req;
123109
@@ -153,12 +139,8 @@ mpi::request RowGatherer<LocalIndexType>::apply_finalize(
153139
154140 dim<2 > send_size (coll_comm_->get_send_size (),
155141 b_local->get_size ()[1 ]);
156- auto send_buffer = matrix::Dense<ValueType>::create (
157- mpi_exec, send_size,
158- make_array_view (
159- mpi_exec, send_size[0 ] * send_size[1 ],
160- reinterpret_cast <ValueType*>(workspace.get_data ())),
161- send_size[1 ]);
142+ auto send_buffer =
143+ workspace.get <ValueType>(mpi_exec, send_size);
162144
163145 auto recv_ptr = x_global->get_local_values ();
164146 auto send_ptr = send_buffer->get_values ();
@@ -189,7 +171,7 @@ std::shared_ptr<const gko::detail::Event> apply_prepare(
189171template <typename LocalIndexType>
190172std::shared_ptr<const gko::detail::Event> apply_prepare (
191173 const RowGatherer<LocalIndexType>* rg, ptr_param<const LinOp> b,
192- array< char > & workspace)
174+ gko::detail::GenericDenseCache & workspace)
193175{
194176 return rg->apply_prepare (b, workspace);
195177}
@@ -208,7 +190,7 @@ template <typename LocalIndexType>
208190mpi::request apply_finalize (const RowGatherer<LocalIndexType>* rg,
209191 ptr_param<const LinOp> b, ptr_param<LinOp> x,
210192 std::shared_ptr<const gko::detail::Event> ev,
211- array< char > & workspace)
193+ gko::detail::GenericDenseCache & workspace)
212194{
213195 return rg->apply_finalize (b, x, ev, workspace);
214196}
@@ -218,9 +200,10 @@ mpi::request apply_finalize(const RowGatherer<LocalIndexType>* rg,
218200 std::shared_ptr<const gko::detail::Event> apply_prepare ( \
219201 const RowGatherer<IndexType>*, ptr_param<const LinOp>)
220202
221- #define GKO_DECLARE_TEST_APPLY_PREPARE_WORKSPACE (IndexType ) \
222- std::shared_ptr<const gko::detail::Event> apply_prepare ( \
223- const RowGatherer<IndexType>*, ptr_param<const LinOp>, array<char >&)
203+ #define GKO_DECLARE_TEST_APPLY_PREPARE_WORKSPACE (IndexType ) \
204+ std::shared_ptr<const gko::detail::Event> apply_prepare ( \
205+ const RowGatherer<IndexType>*, ptr_param<const LinOp>, \
206+ gko::detail::GenericDenseCache&)
224207
225208#define GKO_DECLARE_TEST_APPLY_FINALIZE (IndexType ) \
226209 mpi::request apply_finalize (const RowGatherer<IndexType>* rg, \
@@ -231,7 +214,7 @@ mpi::request apply_finalize(const RowGatherer<LocalIndexType>* rg,
231214 mpi::request apply_finalize (const RowGatherer<IndexType>* rg, \
232215 ptr_param<const LinOp> b, ptr_param<LinOp> x, \
233216 std::shared_ptr<const gko::detail::Event> ev, \
234- array< char > &)
217+ gko::detail::GenericDenseCache &)
235218
236219GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_TEST_APPLY_PREPARE );
237220GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE (GKO_DECLARE_TEST_APPLY_PREPARE_WORKSPACE );
@@ -285,7 +268,7 @@ RowGatherer<LocalIndexType>::RowGatherer(
285268 imap.get_global_size ()}),
286269 coll_comm_(std::move(coll_comm)),
287270 send_idxs_(exec),
288- send_workspace_(exec )
271+ send_cache_( )
289272{
290273 // check that the coll_comm_ and imap have the same recv size
291274 // the same check for the send size is not possible, since the
@@ -365,17 +348,15 @@ RowGatherer<LocalIndexType>::RowGatherer(
365348 : EnablePolymorphicObject<RowGatherer>(exec),
366349 DistributedBase (coll_comm_template->get_base_communicator ()),
367350 coll_comm_(std::move(coll_comm_template)),
368- send_idxs_(exec),
369- send_workspace_(exec)
351+ send_idxs_(exec)
370352{}
371353
372354
373355template <typename LocalIndexType>
374356RowGatherer<LocalIndexType>::RowGatherer(RowGatherer&& o) noexcept
375357 : EnablePolymorphicObject<RowGatherer>(o.get_executor()),
376358 DistributedBase(o.get_communicator()),
377- send_idxs_(o.get_executor()),
378- send_workspace_(o.get_executor())
359+ send_idxs_(o.get_executor())
379360{
380361 *this = std::move (o);
381362}
@@ -404,7 +385,6 @@ RowGatherer<LocalIndexType>& RowGatherer<LocalIndexType>::operator=(
404385 o.coll_comm_ , mpi::detail::create_default_collective_communicator (
405386 o.get_communicator ()));
406387 send_idxs_ = std::move (o.send_idxs_ );
407- send_workspace_ = std::move (o.send_workspace_ );
408388 }
409389 return *this ;
410390}
0 commit comments