Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions benchmark/solver/distributed/solver.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -42,7 +42,7 @@ struct Generator : public DistributedDefaultSystemGenerator<SolverGenerator> {
return Vec::create(
exec, comm, gko::dim<2>{system_matrix->get_size()[0], FLAGS_nrhs},
local_generator.generate_rhs(
exec, gko::as<Mtx>(system_matrix)->get_local_matrix().get(),
exec, gko::as<Mtx>(system_matrix)->get_diag_matrix().get(),
config));
}

Expand All @@ -53,7 +53,7 @@ struct Generator : public DistributedDefaultSystemGenerator<SolverGenerator> {
return Vec::create(
exec, comm, gko::dim<2>{rhs->get_size()[0], FLAGS_nrhs},
local_generator.generate_initial_guess(
exec, gko::as<Mtx>(system_matrix)->get_local_matrix().get(),
exec, gko::as<Mtx>(system_matrix)->get_diag_matrix().get(),
rhs->get_local_vector()));
}
};
Expand Down Expand Up @@ -84,8 +84,8 @@ int main(int argc, char* argv[])
Possible values for "stencil" are: 5pt (2D), 7pt (3D), 9pt (2D), 27pt (3D).
Optional values for "comm_pattern" are: stencil, optimal.
Possible values for "optimal[spmv]" follow the pattern
"<local_format>-<non_local_format>", where both "local_format" and
"non_local_format" can be any of the recognized spmv formats.
"<diag_format>-<off_diag_format>", where both "diag_format" and
"off_diag_format" can be any of the recognized spmv formats.
)";
std::string additional_json = R"(,"optimal":{"spmv":"csr-csr"})";
initialize_argument_parsing_matrix(&argc, &argv, header, format,
Expand Down
24 changes: 12 additions & 12 deletions benchmark/spmv/distributed/spmv.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand All @@ -24,11 +24,11 @@
#include "benchmark/utils/types.hpp"


DEFINE_string(local_formats, "csr",
"A comma-separated list of formats for the local matrix to run. "
DEFINE_string(diag_formats, "csr",
"A comma-separated list of formats for the diag matrix to run. "
"See the 'formats' option for a list of supported versions");
DEFINE_string(non_local_formats, "csr",
"A comma-separated list of formats for the non-local matrix to "
DEFINE_string(off_diag_formats, "csr",
"A comma-separated list of formats for the off-diag matrix to "
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

off-diag to offdiag directly?

"run. See the 'formats' option for a list of supported versions");


Expand All @@ -53,18 +53,18 @@ int main(int argc, char* argv[])

if (do_print) {
std::string extra_information =
"The formats are [" + FLAGS_local_formats + "]x[" +
FLAGS_non_local_formats + "]\n" +
"The formats are [" + FLAGS_diag_formats + "]x[" +
FLAGS_off_diag_formats + "]\n" +
"The number of right hand sides is " + std::to_string(FLAGS_nrhs);
print_general_information(extra_information, exec);
}

auto local_formats = split(FLAGS_local_formats, ',');
auto non_local_formats = split(FLAGS_non_local_formats, ',');
auto diag_formats = split(FLAGS_diag_formats, ',');
auto off_diag_formats = split(FLAGS_off_diag_formats, ',');
std::vector<std::string> formats;
for (const auto& local_fmt : local_formats) {
for (const auto& non_local_fmt : non_local_formats) {
formats.push_back(local_fmt + "-" + non_local_fmt);
for (const auto& diag_fmt : diag_formats) {
for (const auto& off_diag_fmt : off_diag_formats) {
formats.push_back(diag_fmt + "-" + off_diag_fmt);
}
}

Expand Down
8 changes: 4 additions & 4 deletions benchmark/utils/generator.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -256,16 +256,16 @@ struct DistributedDefaultSystemGenerator {
format_name};
}

auto local_mat = formats::matrix_type_factory.at(formats[0])(exec);
auto non_local_mat = formats::matrix_type_factory.at(formats[1])(exec);
auto diag_mat = formats::matrix_type_factory.at(formats[0])(exec);
auto off_diag_mat = formats::matrix_type_factory.at(formats[1])(exec);

auto storage_logger = std::make_shared<StorageLogger>();
if (spmv_case) {
exec->add_logger(storage_logger);
}

auto dist_mat = dist_mtx<etype, itype, global_itype>::create(
exec, comm, local_mat, non_local_mat);
exec, comm, diag_mat, off_diag_mat);
dist_mat->read_distributed(data, part);

if (spmv_case) {
Expand Down
67 changes: 33 additions & 34 deletions common/cuda_hip/distributed/matrix_kernels.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -50,19 +50,18 @@ struct input_type {


template <typename ValueType, typename LocalIndexType, typename GlobalIndexType>
void separate_local_nonlocal(
void separate_diag_off_diag(
std::shared_ptr<const DefaultExecutor> exec,
const device_matrix_data<ValueType, GlobalIndexType>& input,
const experimental::distributed::Partition<LocalIndexType, GlobalIndexType>*
row_partition,
const experimental::distributed::Partition<LocalIndexType, GlobalIndexType>*
col_partition,
experimental::distributed::comm_index_type local_part,
array<LocalIndexType>& local_row_idxs,
array<LocalIndexType>& local_col_idxs, array<ValueType>& local_values,
array<LocalIndexType>& non_local_row_idxs,
array<GlobalIndexType>& non_local_col_idxs,
array<ValueType>& non_local_values)
array<LocalIndexType>& diag_row_idxs, array<LocalIndexType>& diag_col_idxs,
array<ValueType>& diag_values, array<LocalIndexType>& off_diag_row_idxs,
array<GlobalIndexType>& off_diag_col_idxs,
array<ValueType>& off_diag_values)
{
auto input_vals = input.get_const_values();
auto row_part_ids = row_partition->get_part_ids();
Expand Down Expand Up @@ -93,9 +92,9 @@ void separate_local_nonlocal(
input_col_idxs + num_input_elements,
col_range_ids.get_data());

// count number of local<0> and non-local<1> elements. Since the input
// may contain non-local rows, we don't have
// num_local + num_non_local = num_elements and can't just count one of them
// count number of diag<0> and off-diag<1> elements. Since the input
// may contain rows not owned by this rank, we don't have
// num_diag + num_off_diag = num_elements and can't just count one of them
auto range_ids_it = thrust::make_zip_iterator(thrust::make_tuple(
row_range_ids.get_const_data(), col_range_ids.get_const_data()));
auto num_elements_pair = thrust::transform_reduce(
Expand All @@ -104,22 +103,22 @@ void separate_local_nonlocal(
const thrust::tuple<size_type, size_type>& tuple) {
auto row_part = row_part_ids[thrust::get<0>(tuple)];
auto col_part = col_part_ids[thrust::get<1>(tuple)];
bool is_inner_entry =
bool is_diag_entry =
row_part == local_part && col_part == local_part;
bool is_ghost_entry =
bool is_off_diag_entry =
row_part == local_part && col_part != local_part;
return thrust::make_tuple(
is_inner_entry ? size_type{1} : size_type{0},
is_ghost_entry ? size_type{1} : size_type{0});
is_diag_entry ? size_type{1} : size_type{0},
is_off_diag_entry ? size_type{1} : size_type{0});
},
thrust::make_tuple(size_type{}, size_type{}),
[] __host__ __device__(const thrust::tuple<size_type, size_type>& a,
const thrust::tuple<size_type, size_type>& b) {
return thrust::make_tuple(thrust::get<0>(a) + thrust::get<0>(b),
thrust::get<1>(a) + thrust::get<1>(b));
});
auto num_local_elements = thrust::get<0>(num_elements_pair);
auto num_non_local_elements = thrust::get<1>(num_elements_pair);
auto num_diag_elements = thrust::get<0>(num_elements_pair);
auto num_off_diag_elements = thrust::get<1>(num_elements_pair);

// define global-to-local maps for row and column indices
auto map_to_local_row =
Expand All @@ -143,23 +142,23 @@ void separate_local_nonlocal(
as_device_type(input.get_const_values()),
row_range_ids.get_const_data(), col_range_ids.get_const_data()));

// copy and transform local entries into arrays
local_row_idxs.resize_and_reset(num_local_elements);
local_col_idxs.resize_and_reset(num_local_elements);
local_values.resize_and_reset(num_local_elements);
auto local_it = thrust::make_transform_iterator(
// copy and transform diag entries into arrays
diag_row_idxs.resize_and_reset(num_diag_elements);
diag_col_idxs.resize_and_reset(num_diag_elements);
diag_values.resize_and_reset(num_diag_elements);
auto diag_it = thrust::make_transform_iterator(
input_it, [map_to_local_row, map_to_local_col] __host__ __device__(
const input_type input) {
auto local_row = map_to_local_row(input.row, input.row_range);
auto local_col = map_to_local_col(input.col, input.col_range);
return thrust::make_tuple(local_row, local_col, input.val);
});
thrust::copy_if(
policy, local_it, local_it + input.get_num_stored_elements(),
policy, diag_it, diag_it + input.get_num_stored_elements(),
range_ids_it,
thrust::make_zip_iterator(thrust::make_tuple(
local_row_idxs.get_data(), local_col_idxs.get_data(),
as_device_type(local_values.get_data()))),
diag_row_idxs.get_data(), diag_col_idxs.get_data(),
as_device_type(diag_values.get_data()))),
[local_part, row_part_ids, col_part_ids] __host__ __device__(
const thrust::tuple<size_type, size_type>& tuple) {
auto row_part = row_part_ids[thrust::get<0>(tuple)];
Expand All @@ -168,24 +167,24 @@ void separate_local_nonlocal(
});


// copy and transform non-local entries into arrays. this keeps global
// column indices, and also stores the column part id for each non-local
// copy and transform off-diag entries into arrays. this keeps global
// column indices, and also stores the column part id for each off-diag
// entry in an array
non_local_row_idxs.resize_and_reset(num_non_local_elements);
non_local_col_idxs.resize_and_reset(num_non_local_elements);
non_local_values.resize_and_reset(num_non_local_elements);
auto non_local_it = thrust::make_transform_iterator(
off_diag_row_idxs.resize_and_reset(num_off_diag_elements);
off_diag_col_idxs.resize_and_reset(num_off_diag_elements);
off_diag_values.resize_and_reset(num_off_diag_elements);
auto off_diag_it = thrust::make_transform_iterator(
input_it, [map_to_local_row,
col_part_ids] __host__ __device__(const input_type input) {
auto local_row = map_to_local_row(input.row, input.row_range);
return thrust::make_tuple(local_row, input.col, input.val);
});
thrust::copy_if(
policy, non_local_it, non_local_it + input.get_num_stored_elements(),
policy, off_diag_it, off_diag_it + input.get_num_stored_elements(),
range_ids_it,
thrust::make_zip_iterator(thrust::make_tuple(
non_local_row_idxs.get_data(), non_local_col_idxs.get_data(),
as_device_type(non_local_values.get_data()))),
off_diag_row_idxs.get_data(), off_diag_col_idxs.get_data(),
as_device_type(off_diag_values.get_data()))),
[local_part, row_part_ids, col_part_ids] __host__ __device__(
const thrust::tuple<size_type, size_type>& tuple) {
auto row_part = row_part_ids[thrust::get<0>(tuple)];
Expand All @@ -195,7 +194,7 @@ void separate_local_nonlocal(
}

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE(
GKO_DECLARE_SEPARATE_LOCAL_NONLOCAL);
GKO_DECLARE_SEPARATE_DIAG_OFF_DIAG);


} // namespace distributed_matrix
Expand Down
4 changes: 2 additions & 2 deletions core/device_hooks/common_kernels.inc.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -389,7 +389,7 @@ namespace distributed_matrix {


GKO_STUB_VALUE_AND_LOCAL_GLOBAL_INDEX_TYPE_BASE(
GKO_DECLARE_SEPARATE_LOCAL_NONLOCAL);
GKO_DECLARE_SEPARATE_DIAG_OFF_DIAG);


} // namespace distributed_matrix
Expand Down
4 changes: 2 additions & 2 deletions core/distributed/helpers.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors
// SPDX-FileCopyrightText: 2017 - 2026 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

Expand Down Expand Up @@ -193,7 +193,7 @@ inline const LinOp* get_local(const LinOp* mtx)
#if GINKGO_BUILD_MPI
if (is_distributed(mtx)) {
return run_matrix(mtx, [](auto concrete) {
return concrete->get_local_matrix().get();
return concrete->get_diag_matrix().get();
});
}
#endif
Expand Down
Loading
Loading