Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ target_sources(
base/device_matrix_data.cpp
base/executor.cpp
base/index_set.cpp
base/lin_op.cpp
base/memory.cpp
base/mpi.cpp
base/mtx_io.cpp
Expand Down Expand Up @@ -107,6 +108,7 @@ target_sources(
reorder/amd.cpp
reorder/mc64.cpp
reorder/rcm.cpp
reorder/reordered.cpp
reorder/scaled_reordered.cpp
solver/batch_bicgstab.cpp
solver/batch_cg.cpp
Expand Down
188 changes: 188 additions & 0 deletions core/base/lin_op.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
// SPDX-FileCopyrightText: 2025 The Ginkgo authors
//
// SPDX-License-Identifier: BSD-3-Clause

#include <ginkgo/core/base/lin_op.hpp>

namespace gko {


LinOp* LinOp::apply(ptr_param<const LinOp> b, ptr_param<LinOp> x)
{
this->template log<log::Logger::linop_apply_started>(this, b.get(),
x.get());
this->validate_application_parameters(b.get(), x.get());
auto exec = this->get_executor();
this->apply_impl(make_temporary_clone(exec, b).get(),
make_temporary_clone(exec, x).get());
this->template log<log::Logger::linop_apply_completed>(this, b.get(),
x.get());
return this;
}


const LinOp* LinOp::apply(ptr_param<const LinOp> b, ptr_param<LinOp> x) const
{
this->template log<log::Logger::linop_apply_started>(this, b.get(),
x.get());
this->validate_application_parameters(b.get(), x.get());
auto exec = this->get_executor();
this->apply_impl(make_temporary_clone(exec, b).get(),
make_temporary_clone(exec, x).get());
this->template log<log::Logger::linop_apply_completed>(this, b.get(),
x.get());
return this;
}


LinOp* LinOp::apply(ptr_param<const LinOp> alpha, ptr_param<const LinOp> b,
ptr_param<const LinOp> beta, ptr_param<LinOp> x)
{
this->template log<log::Logger::linop_advanced_apply_started>(
this, alpha.get(), b.get(), beta.get(), x.get());
this->validate_application_parameters(alpha.get(), b.get(), beta.get(),
x.get());
auto exec = this->get_executor();
this->apply_impl(make_temporary_clone(exec, alpha).get(),
make_temporary_clone(exec, b).get(),
make_temporary_clone(exec, beta).get(),
make_temporary_clone(exec, x).get());
this->template log<log::Logger::linop_advanced_apply_completed>(
this, alpha.get(), b.get(), beta.get(), x.get());
return this;
}


const LinOp* LinOp::apply(ptr_param<const LinOp> alpha,
ptr_param<const LinOp> b, ptr_param<const LinOp> beta,
ptr_param<LinOp> x) const
{
this->template log<log::Logger::linop_advanced_apply_started>(
this, alpha.get(), b.get(), beta.get(), x.get());
this->validate_application_parameters(alpha.get(), b.get(), beta.get(),
x.get());
auto exec = this->get_executor();
this->apply_impl(make_temporary_clone(exec, alpha).get(),
make_temporary_clone(exec, b).get(),
make_temporary_clone(exec, beta).get(),
make_temporary_clone(exec, x).get());
this->template log<log::Logger::linop_advanced_apply_completed>(
this, alpha.get(), b.get(), beta.get(), x.get());
return this;
}


LinOp& LinOp::operator=(const LinOp&) = default;


LinOp& LinOp::operator=(LinOp&& other)
{
if (this != &other) {
EnableAbstractPolymorphicObject<LinOp>::operator=(std::move(other));
this->set_size(other.get_size());
other.set_size({});
}
return *this;
}


LinOp::LinOp(const LinOp&) = default;


LinOp::LinOp(LinOp&& other)
: EnableAbstractPolymorphicObject<LinOp>(std::move(other)),
size_{std::exchange(other.size_, dim<2>{})}
{}


LinOp::LinOp(std::shared_ptr<const Executor> exec, const dim<2>& size)
: EnableAbstractPolymorphicObject<LinOp>(exec), size_{size}
{}


void LinOp::set_size(const dim<2>& value) noexcept { size_ = value; }


void LinOp::validate_application_parameters(const LinOp* b,
const LinOp* x) const
{
GKO_ASSERT_CONFORMANT(this, b);
GKO_ASSERT_EQUAL_ROWS(this, x);
GKO_ASSERT_EQUAL_COLS(b, x);
}


void LinOp::validate_application_parameters(const LinOp* alpha, const LinOp* b,
const LinOp* beta,
const LinOp* x) const
{
this->validate_application_parameters(b, x);
GKO_ASSERT_EQUAL_DIMENSIONS(alpha, dim<2>(1, 1));
GKO_ASSERT_EQUAL_DIMENSIONS(beta, dim<2>(1, 1));
}


LinOpFactory::ReuseData::ReuseData() = default;


LinOpFactory::ReuseData::~ReuseData() = default;


std::unique_ptr<LinOp> LinOpFactory::generate(
std::shared_ptr<const LinOp> input) const
{
this->template log<log::Logger::linop_factory_generate_started>(
this, input.get());
const auto exec = this->get_executor();
std::unique_ptr<LinOp> generated;
if (input->get_executor() == exec) {
generated = this->AbstractFactory::generate(input);
} else {
generated = this->AbstractFactory::generate(gko::clone(exec, input));
}
this->template log<log::Logger::linop_factory_generate_completed>(
this, input.get(), generated.get());
return generated;
}


std::unique_ptr<LinOpFactory::ReuseData> LinOpFactory::create_empty_reuse_data()
const
{
return std::make_unique<LinOpFactory::ReuseData>();
}


void LinOpFactory::check_reuse_consistent(const LinOp* /*input*/,
ReuseData& /*data*/) const
{}


std::unique_ptr<LinOp> LinOpFactory::generate_reuse(
std::shared_ptr<const LinOp> input, ReuseData& reuse_data) const
{
this->check_reuse_consistent(input.get(), reuse_data);
this->template log<log::Logger::linop_factory_generate_started>(
this, input.get());
const auto exec = this->get_executor();
std::unique_ptr<LinOp> generated;
if (input->get_executor() == exec) {
generated = this->generate_reuse_impl(input, reuse_data);
} else {
generated =
this->generate_reuse_impl(gko::clone(exec, input), reuse_data);
}
this->template log<log::Logger::linop_factory_generate_completed>(
this, input.get(), generated.get());
return generated;
}


std::unique_ptr<LinOp> LinOpFactory::generate_reuse_impl(
std::shared_ptr<const LinOp> input, ReuseData& /*reuse_data*/) const
{
return this->generate_impl(input);
}


} // namespace gko
143 changes: 123 additions & 20 deletions core/factorization/lu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,44 @@ Lu<ValueType, IndexType>::parse(const config::pnode& config,
}


template <typename ValueType, typename IndexType>
class Lu<ValueType, IndexType>::LuReuseData : public BaseReuseData {
friend class Lu;
friend class Factory;

public:
LuReuseData() = default;

bool is_empty() const { return lookup_data_ == nullptr; }

private:
std::unique_ptr<matrix::csr::lookup_data<IndexType>> lookup_data_;
std::unique_ptr<matrix::SparsityCsr<ValueType, IndexType>>
symbolic_factors_;
// diag_idxs are cheap to recompute, ignore them
};


template <typename ValueType, typename IndexType>
std::unique_ptr<LinOpFactory::ReuseData>
Lu<ValueType, IndexType>::create_empty_reuse_data() const
{
return std::make_unique<LuReuseData>();
}


template <typename ValueType, typename IndexType>
void Lu<ValueType, IndexType>::check_reuse_consistent(
const LinOp* input, BaseReuseData& reuse_data) const
{
auto& lrd = *as<LuReuseData>(&reuse_data);
if (lrd.is_empty()) {
return;
}
GKO_ASSERT_IS_SQUARE_MATRIX(input);
}


template <typename ValueType, typename IndexType>
Lu<ValueType, IndexType>::Lu(std::shared_ptr<const Executor> exec,
const parameters_type& params)
Expand All @@ -83,10 +121,44 @@ std::unique_ptr<Factorization<ValueType, IndexType>>
Lu<ValueType, IndexType>::generate(
std::shared_ptr<const LinOp> system_matrix) const
{
auto product =
std::unique_ptr<factorization_type>(static_cast<factorization_type*>(
this->LinOpFactory::generate(std::move(system_matrix)).release()));
return product;
return as<factorization_type>(
this->LinOpFactory::generate(std::move(system_matrix)));
}


template <typename ValueType, typename IndexType>
auto Lu<ValueType, IndexType>::generate_reuse(
std::shared_ptr<const LinOp> input, BaseReuseData& reuse_data) const
-> std::unique_ptr<factorization_type>
{
return as<factorization_type>(
LinOpFactory::generate_reuse(input, reuse_data));
}


template <typename ValueType, typename IndexType>
static std::unique_ptr<matrix::Csr<ValueType, IndexType>> symbolic_factorize(
symbolic_type algorithm, const matrix::Csr<ValueType, IndexType>* mtx)
{
auto exec = mtx->get_executor();
std::unique_ptr<matrix::Csr<ValueType, IndexType>> factors;
switch (algorithm) {
case symbolic_type::general:
exec->run(make_symbolic_lu(mtx, factors));
break;
case symbolic_type::near_symmetric:
exec->run(make_symbolic_lu_near_symm(mtx, factors));
break;
case symbolic_type::symmetric: {
std::unique_ptr<gko::factorization::elimination_forest<IndexType>>
forest;
exec->run(make_symbolic_cholesky(mtx, true, factors, forest));
break;
}
default:
GKO_INVALID_STATE("Invalid symbolic factorization algorithm");
}
return factors;
}


Expand All @@ -100,22 +172,7 @@ std::unique_ptr<LinOp> Lu<ValueType, IndexType>::generate_impl(
const auto num_rows = mtx->get_size()[0];
std::unique_ptr<matrix_type> factors;
if (!parameters_.symbolic_factorization) {
switch (parameters_.symbolic_algorithm) {
case symbolic_type::general:
exec->run(make_symbolic_lu(mtx.get(), factors));
break;
case symbolic_type::near_symmetric:
exec->run(make_symbolic_lu_near_symm(mtx.get(), factors));
break;
case symbolic_type::symmetric: {
std::unique_ptr<gko::factorization::elimination_forest<IndexType>>
forest;
exec->run(make_symbolic_cholesky(mtx.get(), true, factors, forest));
break;
}
default:
GKO_INVALID_STATE("Invalid symbolic factorization algorithm");
}
factors = symbolic_factorize(parameters_.symbolic_algorithm, mtx.get());
} else {
const auto& symbolic = parameters_.symbolic_factorization;
const auto factor_nnz = symbolic->get_num_nonzeros();
Expand Down Expand Up @@ -147,6 +204,52 @@ std::unique_ptr<LinOp> Lu<ValueType, IndexType>::generate_impl(
}


template <typename ValueType, typename IndexType>
std::unique_ptr<LinOp> Lu<ValueType, IndexType>::generate_reuse_impl(
std::shared_ptr<const LinOp> system_matrix, BaseReuseData& reuse_data) const
{
GKO_ASSERT_IS_SQUARE_MATRIX(system_matrix);
const auto exec = this->get_executor();
const auto mtx = copy_and_convert_to<matrix_type>(exec, system_matrix);
const auto num_rows = mtx->get_size()[0];
auto& lurd = *as<LuReuseData>(&reuse_data);
if (lurd.is_empty()) {
if (!parameters_.symbolic_factorization) {
auto tmp_factors =
symbolic_factorize(parameters_.symbolic_algorithm, mtx.get());
lurd.symbolic_factors_ = sparsity_pattern_type::create(exec);
tmp_factors->move_to(lurd.symbolic_factors_);
} else {
lurd.symbolic_factors_ =
parameters_.symbolic_factorization->clone();
}
// setup lookup structure on factors
lurd.lookup_data_ =
std::make_unique<matrix::csr::lookup_data<IndexType>>(
matrix::csr::build_lookup(lurd.symbolic_factors_.get()));
}
auto pattern = lurd.symbolic_factors_.get();
auto pattern_nnz = pattern->get_num_nonzeros();
auto factors = matrix_type::create(
exec, mtx->get_size(), array<ValueType>{exec, pattern_nnz},
make_array_view(exec, pattern_nnz, pattern->get_col_idxs()),
make_array_view(exec, num_rows + 1, pattern->get_row_ptrs()));
auto& lookup = *lurd.lookup_data_;
array<IndexType> diag_idxs{exec, num_rows};
exec->run(make_initialize(
mtx.get(), lookup.storage_offsets.get_const_data(),
lookup.row_descs.get_const_data(), lookup.storage.get_const_data(),
diag_idxs.get_data(), factors.get()));
// run numerical factorization
array<int> tmp{exec};
exec->run(make_factorize(
lookup.storage_offsets.get_const_data(),
lookup.row_descs.get_const_data(), lookup.storage.get_const_data(),
diag_idxs.get_const_data(), factors.get(), true, tmp));
return factorization_type::create_from_combined_lu(std::move(factors));
}


#define GKO_DECLARE_LU(ValueType, IndexType) class Lu<ValueType, IndexType>

GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(GKO_DECLARE_LU);
Expand Down
13 changes: 12 additions & 1 deletion core/matrix/csr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,18 @@ template <typename ValueType, typename IndexType>
void Csr<ValueType, IndexType>::move_to(
SparsityCsr<ValueType, IndexType>* result)
{
this->convert_to(result);
result->col_idxs_ = std::move(this->col_idxs_);
// create empty row_ptrs
result->row_ptrs_ = std::exchange(
this->row_ptrs_, array<IndexType>{this->get_executor(), {IndexType{}}});
if (!result->value_.get_data()) {
result->value_ =
array<ValueType>(result->get_executor(), {one<ValueType>()});
}
result->set_size(this->get_size());
this->set_size(dim<2>{});
this->values_.clear();
this->make_srow();
}


Expand Down
Loading
Loading