Skip to content

Commit d15ec2d

Browse files
authored
Merge Use GenericDenseCache to manage workspace
This PR replaces the `array<char>` workspace in the distributed `RowGatherer` with a `GenericDenseCache`. This simplifies the allocation and resizing of the workspace. Related PR: #1996
2 parents 88f6b7b + d91f6a6 commit d15ec2d

3 files changed

Lines changed: 48 additions & 105 deletions

File tree

core/distributed/row_gatherer.cpp

Lines changed: 23 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,14 @@ template <typename LocalIndexType>
3030
mpi::request RowGatherer<LocalIndexType>::apply_async(ptr_param<const LinOp> b,
3131
ptr_param<LinOp> x) const
3232
{
33-
return apply_async(b, x, send_workspace_);
33+
return apply_async(b, x, send_cache_);
3434
}
3535

3636

3737
template <typename LocalIndexType>
3838
mpi::request RowGatherer<LocalIndexType>::apply_async(
39-
ptr_param<const LinOp> b, ptr_param<LinOp> x, array<char>& workspace) const
39+
ptr_param<const LinOp> b, ptr_param<LinOp> x,
40+
gko::detail::GenericDenseCache& workspace) const
4041
{
4142
auto ev = this->apply_prepare(b, workspace);
4243
return this->apply_finalize(b, x, ev, workspace);
@@ -46,13 +47,13 @@ template <typename LocalIndexType>
4647
std::shared_ptr<const gko::detail::Event>
4748
RowGatherer<LocalIndexType>::apply_prepare(ptr_param<const LinOp> b) const
4849
{
49-
return apply_prepare(b, send_workspace_);
50+
return apply_prepare(b, send_cache_);
5051
}
5152

5253
template <typename LocalIndexType>
5354
std::shared_ptr<const gko::detail::Event>
54-
RowGatherer<LocalIndexType>::apply_prepare(ptr_param<const LinOp> b,
55-
array<char>& workspace) const
55+
RowGatherer<LocalIndexType>::apply_prepare(
56+
ptr_param<const LinOp> b, gko::detail::GenericDenseCache& workspace) const
5657
{
5758
std::shared_ptr<const gko::detail::Event> ev = nullptr;
5859
auto exec = this->get_executor();
@@ -79,24 +80,8 @@ RowGatherer<LocalIndexType>::apply_prepare(ptr_param<const LinOp> b,
7980

8081
dim<2> send_size(coll_comm_->get_send_size(),
8182
b_local->get_size()[1]);
82-
auto send_size_in_bytes =
83-
sizeof(ValueType) * send_size[0] * send_size[1];
84-
// TODO: can not combine them to assignment because array
85-
// assignment will copy the data to the place without
86-
// changing executor.
87-
if (!workspace.get_executor() ||
88-
!mpi_exec->memory_accessible(workspace.get_executor())) {
89-
workspace.set_executor(mpi_exec);
90-
}
91-
if (send_size_in_bytes > workspace.get_size()) {
92-
workspace.resize_and_reset(send_size_in_bytes);
93-
}
94-
auto send_buffer = matrix::Dense<ValueType>::create(
95-
mpi_exec, send_size,
96-
make_array_view(
97-
mpi_exec, send_size[0] * send_size[1],
98-
reinterpret_cast<ValueType*>(workspace.get_data())),
99-
send_size[1]);
83+
auto send_buffer =
84+
workspace.get<ValueType>(mpi_exec, send_size);
10085
b_local->row_gather(&send_idxs_, send_buffer);
10186
b_local->get_executor()->run(event::make_record_event(ev));
10287
});
@@ -110,14 +95,15 @@ mpi::request RowGatherer<LocalIndexType>::apply_finalize(
11095
ptr_param<const LinOp> b, ptr_param<LinOp> x,
11196
std::shared_ptr<const gko::detail::Event> ev) const
11297
{
113-
auto req = apply_finalize(b, x, ev, send_workspace_);
98+
auto req = apply_finalize(b, x, ev, send_cache_);
11499
return req;
115100
}
116101

117102
template <typename LocalIndexType>
118103
mpi::request RowGatherer<LocalIndexType>::apply_finalize(
119104
ptr_param<const LinOp> b, ptr_param<LinOp> x,
120-
std::shared_ptr<const gko::detail::Event> ev, array<char>& workspace) const
105+
std::shared_ptr<const gko::detail::Event> ev,
106+
gko::detail::GenericDenseCache& workspace) const
121107
{
122108
mpi::request req;
123109

@@ -153,12 +139,8 @@ mpi::request RowGatherer<LocalIndexType>::apply_finalize(
153139

154140
dim<2> send_size(coll_comm_->get_send_size(),
155141
b_local->get_size()[1]);
156-
auto send_buffer = matrix::Dense<ValueType>::create(
157-
mpi_exec, send_size,
158-
make_array_view(
159-
mpi_exec, send_size[0] * send_size[1],
160-
reinterpret_cast<ValueType*>(workspace.get_data())),
161-
send_size[1]);
142+
auto send_buffer =
143+
workspace.get<ValueType>(mpi_exec, send_size);
162144

163145
auto recv_ptr = x_global->get_local_values();
164146
auto send_ptr = send_buffer->get_values();
@@ -189,7 +171,7 @@ std::shared_ptr<const gko::detail::Event> apply_prepare(
189171
template <typename LocalIndexType>
190172
std::shared_ptr<const gko::detail::Event> apply_prepare(
191173
const RowGatherer<LocalIndexType>* rg, ptr_param<const LinOp> b,
192-
array<char>& workspace)
174+
gko::detail::GenericDenseCache& workspace)
193175
{
194176
return rg->apply_prepare(b, workspace);
195177
}
@@ -208,7 +190,7 @@ template <typename LocalIndexType>
208190
mpi::request apply_finalize(const RowGatherer<LocalIndexType>* rg,
209191
ptr_param<const LinOp> b, ptr_param<LinOp> x,
210192
std::shared_ptr<const gko::detail::Event> ev,
211-
array<char>& workspace)
193+
gko::detail::GenericDenseCache& workspace)
212194
{
213195
return rg->apply_finalize(b, x, ev, workspace);
214196
}
@@ -218,9 +200,10 @@ mpi::request apply_finalize(const RowGatherer<LocalIndexType>* rg,
218200
std::shared_ptr<const gko::detail::Event> apply_prepare( \
219201
const RowGatherer<IndexType>*, ptr_param<const LinOp>)
220202

221-
#define GKO_DECLARE_TEST_APPLY_PREPARE_WORKSPACE(IndexType) \
222-
std::shared_ptr<const gko::detail::Event> apply_prepare( \
223-
const RowGatherer<IndexType>*, ptr_param<const LinOp>, array<char>&)
203+
#define GKO_DECLARE_TEST_APPLY_PREPARE_WORKSPACE(IndexType) \
204+
std::shared_ptr<const gko::detail::Event> apply_prepare( \
205+
const RowGatherer<IndexType>*, ptr_param<const LinOp>, \
206+
gko::detail::GenericDenseCache&)
224207

225208
#define GKO_DECLARE_TEST_APPLY_FINALIZE(IndexType) \
226209
mpi::request apply_finalize(const RowGatherer<IndexType>* rg, \
@@ -231,7 +214,7 @@ mpi::request apply_finalize(const RowGatherer<LocalIndexType>* rg,
231214
mpi::request apply_finalize(const RowGatherer<IndexType>* rg, \
232215
ptr_param<const LinOp> b, ptr_param<LinOp> x, \
233216
std::shared_ptr<const gko::detail::Event> ev, \
234-
array<char>&)
217+
gko::detail::GenericDenseCache&)
235218

236219
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_TEST_APPLY_PREPARE);
237220
GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(GKO_DECLARE_TEST_APPLY_PREPARE_WORKSPACE);
@@ -285,7 +268,7 @@ RowGatherer<LocalIndexType>::RowGatherer(
285268
imap.get_global_size()}),
286269
coll_comm_(std::move(coll_comm)),
287270
send_idxs_(exec),
288-
send_workspace_(exec)
271+
send_cache_()
289272
{
290273
// check that the coll_comm_ and imap have the same recv size
291274
// the same check for the send size is not possible, since the
@@ -365,17 +348,15 @@ RowGatherer<LocalIndexType>::RowGatherer(
365348
: EnablePolymorphicObject<RowGatherer>(exec),
366349
DistributedBase(coll_comm_template->get_base_communicator()),
367350
coll_comm_(std::move(coll_comm_template)),
368-
send_idxs_(exec),
369-
send_workspace_(exec)
351+
send_idxs_(exec)
370352
{}
371353

372354

373355
template <typename LocalIndexType>
374356
RowGatherer<LocalIndexType>::RowGatherer(RowGatherer&& o) noexcept
375357
: EnablePolymorphicObject<RowGatherer>(o.get_executor()),
376358
DistributedBase(o.get_communicator()),
377-
send_idxs_(o.get_executor()),
378-
send_workspace_(o.get_executor())
359+
send_idxs_(o.get_executor())
379360
{
380361
*this = std::move(o);
381362
}
@@ -404,7 +385,6 @@ RowGatherer<LocalIndexType>& RowGatherer<LocalIndexType>::operator=(
404385
o.coll_comm_, mpi::detail::create_default_collective_communicator(
405386
o.get_communicator()));
406387
send_idxs_ = std::move(o.send_idxs_);
407-
send_workspace_ = std::move(o.send_workspace_);
408388
}
409389
return *this;
410390
}

include/ginkgo/core/distributed/row_gatherer.hpp

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -42,7 +42,7 @@ std::shared_ptr<const gko::detail::Event> apply_prepare(
4242
template <typename LocalIndexType>
4343
std::shared_ptr<const gko::detail::Event> apply_prepare(
4444
const RowGatherer<LocalIndexType>* rg, ptr_param<const LinOp> b,
45-
array<char>& workspace);
45+
gko::detail::GenericDenseCache& workspace);
4646

4747
// give access to test function on protected function
4848
template <typename LocalIndexType>
@@ -55,7 +55,7 @@ template <typename LocalIndexType>
5555
mpi::request apply_finalize(const RowGatherer<LocalIndexType>* rg,
5656
ptr_param<const LinOp> b, ptr_param<LinOp> x,
5757
std::shared_ptr<const gko::detail::Event>,
58-
array<char>& workspace);
58+
gko::detail::GenericDenseCache& workspace);
5959

6060

6161
} // namespace detail
@@ -99,16 +99,16 @@ class RowGatherer final
9999
friend std::shared_ptr<const gko::detail::Event>
100100
detail::apply_prepare<LocalIndexType>(const RowGatherer* rg,
101101
ptr_param<const LinOp> b);
102-
friend std::shared_ptr<const gko::detail::Event>
103-
detail::apply_prepare<LocalIndexType>(const RowGatherer* rg,
104-
ptr_param<const LinOp> b,
105-
array<char>& workspace);
102+
friend std::shared_ptr<const gko::detail::Event> detail::apply_prepare<
103+
LocalIndexType>(const RowGatherer* rg, ptr_param<const LinOp> b,
104+
gko::detail::GenericDenseCache& workspace);
106105
friend mpi::request detail::apply_finalize<LocalIndexType>(
107106
const RowGatherer* rg, ptr_param<const LinOp> b, ptr_param<LinOp> x,
108107
std::shared_ptr<const gko::detail::Event>);
109108
friend mpi::request detail::apply_finalize<LocalIndexType>(
110109
const RowGatherer* rg, ptr_param<const LinOp> b, ptr_param<LinOp> x,
111-
std::shared_ptr<const gko::detail::Event>, array<char>& workspace);
110+
std::shared_ptr<const gko::detail::Event>,
111+
gko::detail::GenericDenseCache& workspace);
112112

113113
public:
114114
/**
@@ -147,9 +147,9 @@ class RowGatherer final
147147
* @return a mpi::request for this task. The task is guaranteed to
148148
* be completed only after `.wait()` has been called on it.
149149
*/
150-
[[nodiscard]] mpi::request apply_async(ptr_param<const LinOp> b,
151-
ptr_param<LinOp> x,
152-
array<char>& workspace) const;
150+
[[nodiscard]] mpi::request apply_async(
151+
ptr_param<const LinOp> b, ptr_param<LinOp> x,
152+
gko::detail::GenericDenseCache& workspace) const;
153153

154154
/**
155155
* Returns the size of the row gatherer.
@@ -243,15 +243,17 @@ class RowGatherer final
243243
ptr_param<const LinOp> b) const;
244244

245245
std::shared_ptr<const gko::detail::Event> apply_prepare(
246-
ptr_param<const LinOp> b, array<char>& workspace) const;
246+
ptr_param<const LinOp> b,
247+
gko::detail::GenericDenseCache& workspace) const;
247248

248249
mpi::request apply_finalize(
249250
ptr_param<const LinOp> b, ptr_param<LinOp> x,
250251
std::shared_ptr<const gko::detail::Event>) const;
251252

252-
mpi::request apply_finalize(ptr_param<const LinOp> b, ptr_param<LinOp> x,
253-
std::shared_ptr<const gko::detail::Event>,
254-
array<char>& workspace) const;
253+
mpi::request apply_finalize(
254+
ptr_param<const LinOp> b, ptr_param<LinOp> x,
255+
std::shared_ptr<const gko::detail::Event>,
256+
gko::detail::GenericDenseCache& workspace) const;
255257

256258
private:
257259
/**
@@ -281,7 +283,7 @@ class RowGatherer final
281283
dim<2> size_;
282284
std::shared_ptr<const mpi::CollectiveCommunicator> coll_comm_;
283285
array<LocalIndexType> send_idxs_;
284-
mutable array<char> send_workspace_;
286+
mutable gko::detail::GenericDenseCache send_cache_;
285287
};
286288

287289

test/mpi/distributed/row_gatherer.cpp

Lines changed: 7 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
1+
// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors
22
//
33
// SPDX-License-Identifier: BSD-3-Clause
44

@@ -142,7 +142,7 @@ TYPED_TEST(RowGatherer, CanApplyAsyncWithWorkspace)
142142
auto x = Vector::create(this->mpi_exec, this->comm,
143143
gko::dim<2>{this->rg->get_size()[0], 1},
144144
gko::dim<2>{expected.get_size(), 1});
145-
gko::array<char> workspace;
145+
gko::detail::GenericDenseCache workspace;
146146

147147
auto req = this->rg->apply_async(b, x, workspace);
148148
req.wait();
@@ -153,7 +153,6 @@ TYPED_TEST(RowGatherer, CanApplyAsyncWithWorkspace)
153153
expected, 1));
154154
GKO_ASSERT_MTX_NEAR(x->get_local_vector(), expected_vec->get_local_vector(),
155155
0.0);
156-
ASSERT_GT(workspace.get_size(), 0);
157156
}
158157

159158

@@ -173,8 +172,8 @@ TYPED_TEST(RowGatherer, CanApplyAsyncMultipleTimesWithWorkspace)
173172
gko::dim<2>{this->rg->get_size()[0], 1},
174173
gko::dim<2>{expected.get_size(), 1});
175174
auto x2 = gko::clone(x1);
176-
gko::array<char> workspace1;
177-
gko::array<char> workspace2;
175+
gko::detail::GenericDenseCache workspace1;
176+
gko::detail::GenericDenseCache workspace2;
178177

179178
auto req1 = this->rg->apply_async(b1, x1, workspace1);
180179
auto req2 = this->rg->apply_async(b2, x2, workspace2);
@@ -299,7 +298,7 @@ TYPED_TEST(RowGatherer, CanApplyAsyncWithEventAndWorkspace)
299298
auto x = Vector::create(this->mpi_exec, this->comm,
300299
gko::dim<2>{this->rg->get_size()[0], 1},
301300
gko::dim<2>{expected.get_size(), 1});
302-
gko::array<char> workspace;
301+
gko::detail::GenericDenseCache workspace;
303302

304303
auto ev = apply_prepare(this->rg.get(), b, workspace);
305304
auto req = apply_finalize(this->rg.get(), b, x, ev, workspace);
@@ -311,7 +310,6 @@ TYPED_TEST(RowGatherer, CanApplyAsyncWithEventAndWorkspace)
311310
expected, 1));
312311
GKO_ASSERT_MTX_NEAR(x->get_local_vector(), expected_vec->get_local_vector(),
313312
0.0);
314-
ASSERT_GT(workspace.get_size(), 0);
315313
}
316314

317315

@@ -331,8 +329,8 @@ TYPED_TEST(RowGatherer, CanApplyAsyncMultipleTimesWithEventAndWorkspace)
331329
gko::dim<2>{this->rg->get_size()[0], 1},
332330
gko::dim<2>{expected.get_size(), 1});
333331
auto x2 = gko::clone(x1);
334-
gko::array<char> workspace1;
335-
gko::array<char> workspace2;
332+
gko::detail::GenericDenseCache workspace1;
333+
gko::detail::GenericDenseCache workspace2;
336334

337335
auto ev1 = apply_prepare(this->rg.get(), b1, workspace1);
338336
auto ev2 = apply_prepare(this->rg.get(), b2, workspace2);
@@ -354,43 +352,6 @@ TYPED_TEST(RowGatherer, CanApplyAsyncMultipleTimesWithEventAndWorkspace)
354352
}
355353

356354

357-
TYPED_TEST(
358-
RowGatherer,
359-
CanApplyAsyncWithEventAndWorkspaceEnsuringPrepareAndFinalizeSeparately)
360-
{
361-
using Dense = gko::matrix::Dense<double>;
362-
using Vector = gko::experimental::distributed::Vector<double>;
363-
int rank = this->comm.rank();
364-
auto offset = static_cast<double>(rank * 3);
365-
auto b = Vector::create(
366-
this->exec, this->comm, gko::dim<2>{18, 1},
367-
gko::initialize<Dense>({offset, offset + 1, offset + 2}, this->exec));
368-
auto expected = this->template create_recv_connections<double>()[rank];
369-
auto modified_expected =
370-
gko::array<double>(expected.get_executor(), expected.get_size());
371-
modified_expected.fill(0.0);
372-
auto x = Vector::create(this->mpi_exec, this->comm,
373-
gko::dim<2>{this->rg->get_size()[0], 1},
374-
gko::dim<2>{expected.get_size(), 1});
375-
gko::array<char> workspace;
376-
377-
auto ev = apply_prepare(this->rg.get(), b, workspace);
378-
// we modify the workspace to all 0
379-
workspace.fill(static_cast<char>(0));
380-
this->exec->synchronize();
381-
auto req = apply_finalize(this->rg.get(), b, x, ev, workspace);
382-
req.wait();
383-
384-
auto expected_vec = Vector::create(
385-
this->mpi_exec, this->comm, gko::dim<2>{this->rg->get_size()[0], 1},
386-
Dense::create(this->mpi_exec,
387-
gko::dim<2>{modified_expected.get_size(), 1},
388-
modified_expected, 1));
389-
GKO_ASSERT_MTX_NEAR(x->get_local_vector(), expected_vec->get_local_vector(),
390-
0.0);
391-
}
392-
393-
394355
TYPED_TEST(RowGatherer, ThrowsOnNonMatchingExecutor)
395356
{
396357
if (this->mpi_exec == this->exec) {

0 commit comments

Comments
 (0)