Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 141 additions & 4 deletions src/smith/differentiable_numerics/differentiable_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,18 +266,155 @@ std::vector<DifferentiableBlockSolver::FieldPtr> LinearDifferentiableBlockSolver
return u_duals;
}

NonlinearDifferentiableBlockSolver::NonlinearDifferentiableBlockSolver(std::unique_ptr<EquationSolver> s)
: nonlinear_solver_(std::move(s))
{
}

/// @brief Currently performs no setup work; the argument is unused.
void NonlinearDifferentiableBlockSolver::completeSetup(const std::vector<FieldT>& /*us*/)
{
  // The preconditioner initialization below is disabled:
  // initializeSolver(&nonlinear_solver_->preconditioner(), u);
}

std::vector<DifferentiableBlockSolver::FieldPtr> NonlinearDifferentiableBlockSolver::solve(
const std::vector<FieldPtr>& u_guesses,
std::function<std::vector<mfem::Vector>(const std::vector<FieldPtr>&)> residual_funcs,
std::function<std::vector<std::vector<MatrixPtr>>(const std::vector<FieldPtr>&)> jacobian_funcs) const
{
SMITH_MARK_FUNCTION;

int num_rows = static_cast<int>(u_guesses.size());
SLIC_ERROR_IF(num_rows < 0, "Number of residual rows must be non-negative");

mfem::Array<int> block_offsets;
block_offsets.SetSize(num_rows + 1);
block_offsets[0] = 0;
for (int row_i = 0; row_i < num_rows; ++row_i) {
block_offsets[row_i + 1] = u_guesses[static_cast<size_t>(row_i)]->space().TrueVSize();
}
block_offsets.PartialSum();

auto block_u = std::make_unique<mfem::BlockVector>(block_offsets);
for (int row_i = 0; row_i < num_rows; ++row_i) {
block_u->GetBlock(row_i) = *u_guesses[static_cast<size_t>(row_i)];
}

auto block_r = std::make_unique<mfem::BlockVector>(block_offsets);

auto residual_op_ = std::make_unique<mfem_ext::StdFunctionOperator>(
block_u->Size(),
[&residual_funcs, num_rows, &u_guesses, &block_r](const mfem::Vector& u_, mfem::Vector& r_) {
const mfem::BlockVector* u = dynamic_cast<const mfem::BlockVector*>(&u_);
SLIC_ERROR_IF(!u, "Invalid u cast in block differentiable solver to a blocl vector");
for (int row_i = 0; row_i < num_rows; ++row_i) {
*u_guesses[static_cast<size_t>(row_i)] = u->GetBlock(row_i);
}
auto residuals = residual_funcs(u_guesses);
// auto block_r = std::make_unique<mfem::BlockVector>(block_offsets);
// auto block_r = dynamic_cast<mfem::BlockVector*>(&r_);
SLIC_ERROR_IF(!block_r, "Invalid r cast in block differentiable solver to a block vector");
for (int row_i = 0; row_i < num_rows; ++row_i) {
auto r = residuals[static_cast<size_t>(row_i)];
block_r->GetBlock(row_i) = r;
}
r_ = *block_r;
},
[this, &block_offsets, &u_guesses, jacobian_funcs, num_rows](const mfem::Vector& u_) -> mfem::Operator& {
const mfem::BlockVector* u = dynamic_cast<const mfem::BlockVector*>(&u_);
SLIC_ERROR_IF(!u, "Invalid u cast in block differentiable solver to a block vector");
for (int row_i = 0; row_i < num_rows; ++row_i) {
*u_guesses[static_cast<size_t>(row_i)] = u->GetBlock(row_i);
}
block_jac_ = std::make_unique<mfem::BlockOperator>(block_offsets);
matrix_of_jacs_ = jacobian_funcs(u_guesses);
for (int i = 0; i < num_rows; ++i) {
for (int j = 0; j < num_rows; ++j) {
auto& J = matrix_of_jacs_[static_cast<size_t>(i)][static_cast<size_t>(j)];
if (J) {
block_jac_->SetBlock(i, j, J.get());
}
}
}
return *block_jac_;
});
nonlinear_solver_->setOperator(*residual_op_);
nonlinear_solver_->solve(*block_u);

for (int row_i = 0; row_i < num_rows; ++row_i) {
*u_guesses[static_cast<size_t>(row_i)] = block_u->GetBlock(row_i);
}

return u_guesses;
}

std::vector<DifferentiableBlockSolver::FieldPtr> NonlinearDifferentiableBlockSolver::solveAdjoint(
const std::vector<DualPtr>& u_bars, std::vector<std::vector<MatrixPtr>>& jacobian_transposed) const
{
SMITH_MARK_FUNCTION;

int num_rows = static_cast<int>(u_bars.size());
SLIC_ERROR_IF(num_rows < 0, "Number of residual rows must be non-negative");

std::vector<DifferentiableBlockSolver::FieldPtr> u_duals(static_cast<size_t>(num_rows));
for (int row_i = 0; row_i < num_rows; ++row_i) {
u_duals[static_cast<size_t>(row_i)] = std::make_shared<DifferentiableBlockSolver::FieldT>(
u_bars[static_cast<size_t>(row_i)]->space(), "u_dual_" + std::to_string(row_i));
}

mfem::Array<int> block_offsets;
block_offsets.SetSize(num_rows + 1);
block_offsets[0] = 0;
for (int row_i = 0; row_i < num_rows; ++row_i) {
block_offsets[row_i + 1] = u_bars[static_cast<size_t>(row_i)]->space().TrueVSize();
}
block_offsets.PartialSum();

auto block_ds = std::make_unique<mfem::BlockVector>(block_offsets);
*block_ds = 0.0;

auto block_r = std::make_unique<mfem::BlockVector>(block_offsets);
for (int row_i = 0; row_i < num_rows; ++row_i) {
block_r->GetBlock(row_i) = *u_bars[static_cast<size_t>(row_i)];
}

auto block_jac = std::make_unique<mfem::BlockOperator>(block_offsets);
for (int i = 0; i < num_rows; ++i) {
for (int j = 0; j < num_rows; ++j) {
block_jac->SetBlock(i, j, jacobian_transposed[static_cast<size_t>(i)][static_cast<size_t>(j)].get());
}
}

auto& linear_solver = nonlinear_solver_->linearSolver();
linear_solver.SetOperator(*block_jac);
linear_solver.Mult(*block_r, *block_ds);

for (int row_i = 0; row_i < num_rows; ++row_i) {
*u_duals[static_cast<size_t>(row_i)] = block_ds->GetBlock(row_i);
}

return u_duals;
}

/// @brief Build a LinearDifferentiableSolver from linear solver options.
/// @param linear_opts options forwarded to buildLinearSolverAndPreconditioner
/// @param mesh mesh whose communicator is forwarded to buildLinearSolverAndPreconditioner
/// @return shared pointer owning the newly created solver
std::shared_ptr<LinearDifferentiableSolver> buildDifferentiableLinearSolver(LinearSolverOptions linear_opts,
                                                                            const smith::Mesh& mesh)
{
  // The builder returns the solver together with its preconditioner; both are moved into the wrapper.
  auto [solver, preconditioner] = smith::buildLinearSolverAndPreconditioner(linear_opts, mesh.getComm());
  auto differentiable_solver =
      std::make_shared<smith::LinearDifferentiableSolver>(std::move(solver), std::move(preconditioner));
  return differentiable_solver;
}

std::shared_ptr<NonlinearDifferentiableSolver> buildDifferentiableNonlinearSolver(
smith::NonlinearSolverOptions nonlinear_opts, LinearSolverOptions linear_opts, const smith::Mesh& mesh)
std::shared_ptr<NonlinearDifferentiableSolver> buildDifferentiableNonlinearSolver(NonlinearSolverOptions nonlinear_opts,
LinearSolverOptions linear_opts,
const smith::Mesh& mesh)
{
auto solid_solver = std::make_unique<EquationSolver>(nonlinear_opts, linear_opts, mesh.getComm());
return std::make_shared<NonlinearDifferentiableSolver>(std::move(solid_solver));
}

/// @brief Build a NonlinearDifferentiableBlockSolver backed by a newly constructed EquationSolver.
/// @param nonlinear_opts nonlinear solver options passed to the EquationSolver
/// @param linear_opts linear solver options passed to the EquationSolver
/// @param mesh mesh whose communicator is passed to the EquationSolver
/// @return shared pointer owning the newly created block solver
std::shared_ptr<NonlinearDifferentiableBlockSolver> buildDifferentiableNonlinearBlockSolver(
    NonlinearSolverOptions nonlinear_opts, LinearSolverOptions linear_opts, const smith::Mesh& mesh)
{
  auto equation_solver = std::make_unique<EquationSolver>(nonlinear_opts, linear_opts, mesh.getComm());
  return std::make_shared<NonlinearDifferentiableBlockSolver>(std::move(equation_solver));
}

} // namespace smith
38 changes: 38 additions & 0 deletions src/smith/differentiable_numerics/differentiable_solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace mfem {
class Solver;
class Vector;
class HypreParMatrix;
class BlockOperator;
} // namespace mfem

namespace smith {
Expand Down Expand Up @@ -175,6 +176,36 @@ class LinearDifferentiableBlockSolver : public DifferentiableBlockSolver {
mutable std::unique_ptr<mfem::Solver> mfem_preconditioner; ///< stored mfem block preconditioner
};

/// @brief Implementation of the DifferentiableBlockSolver interface for the special case of nonlinear solves with
/// linear adjoint solves
class NonlinearDifferentiableBlockSolver : public DifferentiableBlockSolver {
 public:
  /// @brief Construct from the equation solver used for the nonlinear forward solve; its linear solver is
  /// used for the adjoint solve
  /// @param s equation solver; ownership is transferred to this object
  explicit NonlinearDifferentiableBlockSolver(std::unique_ptr<EquationSolver> s);

  /// @brief Setup hook of the DifferentiableBlockSolver interface
  /// @param us fields, one per block row
  void completeSetup(const std::vector<FieldT>& us) override;

  /// @brief Solve the nonlinear block system
  /// @param u_guesses initial guesses, one field per block row
  /// @param residuals callback returning one residual vector per block row
  /// @param jacobians callback returning the (row, column) Jacobian blocks
  /// @return the solution fields, one per block row
  std::vector<FieldPtr> solve(
      const std::vector<FieldPtr>& u_guesses,
      std::function<std::vector<mfem::Vector>(const std::vector<FieldPtr>&)> residuals,
      std::function<std::vector<std::vector<MatrixPtr>>(const std::vector<FieldPtr>&)> jacobians) const override;

  /// @brief Solve the linear adjoint block system
  /// @param u_bars right-hand-side dual vectors, one per block row
  /// @param jacobian_transposed (row, column) blocks of the transposed Jacobian
  /// @return the adjoint solution fields, one per block row
  std::vector<FieldPtr> solveAdjoint(const std::vector<DualPtr>& u_bars,
                                     std::vector<std::vector<MatrixPtr>>& jacobian_transposed) const override;

  mutable std::unique_ptr<mfem::BlockOperator>
      block_jac_;  ///< Need to hold an instance of a block operator to work with the mfem solver interface
  mutable std::vector<std::vector<MatrixPtr>>
      matrix_of_jacs_;  ///< Holds the block matrices so that they do not go out of scope before the mfem solver
                        ///< is done using them through block_jac_

  mutable std::unique_ptr<EquationSolver>
      nonlinear_solver_;  ///< the nonlinear equation solver used for the forward pass
};

/// @brief Create a differentiable linear solver
/// @param linear_opts linear options struct
/// @param mesh mesh
Expand All @@ -189,4 +220,11 @@ std::shared_ptr<NonlinearDifferentiableSolver> buildDifferentiableNonlinearSolve
LinearSolverOptions linear_opts,
const smith::Mesh& mesh);

/// @brief Create a differentiable nonlinear block solver
/// @param nonlinear_opts nonlinear options struct
/// @param linear_opts linear options struct
/// @param mesh mesh
/// @return shared pointer to the newly created NonlinearDifferentiableBlockSolver
std::shared_ptr<NonlinearDifferentiableBlockSolver> buildDifferentiableNonlinearBlockSolver(
NonlinearSolverOptions nonlinear_opts, LinearSolverOptions linear_opts, const smith::Mesh& mesh);

} // namespace smith
Original file line number Diff line number Diff line change
Expand Up @@ -72,39 +72,22 @@ class DirichletBoundaryConditions {

/// @brief Specify time and space varying Dirichlet boundary conditions over a domain.
/// @param domain All dofs in this domain have boundary conditions applied to it.
/// @param component component direction to apply boundary condition to if the underlying field is a vector-field.
/// @param applied_displacement applied_displacement is a functor which takes time, and a
/// smith::tensor<double,spatial_dim> corresponding to the spatial coordinate. The functor must return a double. For
/// example: [](double t, smith::tensor<double, dim> X) { return 1.0; }
template <int spatial_dim, typename AppliedDisplacementFunction>
void setVectorBCs(const Domain& domain, int component, AppliedDisplacementFunction applied_displacement)
void setScalarBCs(const Domain& domain, AppliedDisplacementFunction applied_displacement)
{
const int field_dim = space_.GetVDim();
SLIC_ERROR_IF(component >= field_dim,
axom::fmt::format("Trying to set boundary conditions on a field with dim {}, using component {}",
field_dim, component));
auto mfem_coefficient_function = [applied_displacement](const mfem::Vector& X_mfem, double t) {
auto X = make_tensor<spatial_dim>([&X_mfem](int k) { return X_mfem[k]; });
return applied_displacement(t, X);
};

auto dof_list = domain.dof_list(&space_);
// scalar ldofs -> vector ldofs
space_.DofsToVDofs(component, dof_list);
space_.DofsToVDofs(static_cast<int>(0), dof_list);

auto component_disp_bdr_coef_ = std::make_shared<mfem::FunctionCoefficient>(mfem_coefficient_function);
bcs_.addEssential(dof_list, component_disp_bdr_coef_, space_, component);
}

/// @brief Specify time and space varying Dirichlet boundary conditions over a domain.
/// @param domain All dofs in this domain have boundary conditions applied to it.
/// @param applied_displacement applied_displacement is a functor which takes time, and a
/// smith::tensor<double,spatial_dim> corresponding to the spatial coordinate. The functor must return a double. For
/// example: [](double t, smith::tensor<double, dim> X) { return 1.0; }
template <int spatial_dim, typename AppliedDisplacementFunction>
void setScalarBCs(const Domain& domain, AppliedDisplacementFunction applied_displacement)
{
setScalarBCs<spatial_dim>(domain, 0, applied_displacement);
bcs_.addEssential(dof_list, component_disp_bdr_coef_, space_, 0);
}

/// @brief Constrain the dofs of a scalar field over a domain
Expand Down
67 changes: 43 additions & 24 deletions src/smith/differentiable_numerics/nonlinear_solve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ std::vector<FieldState> block_solve(const std::vector<WeakForm*>& residual_evals
const std::vector<std::vector<FieldState>>& states,
const std::vector<std::vector<FieldState>>& params, const TimeInfo& time_info,
const DifferentiableBlockSolver* solver,
const std::vector<BoundaryConditionManager*> bc_managers)
const std::vector<const BoundaryConditionManager*> bc_managers)
{
SMITH_MARK_FUNCTION;
size_t num_rows_ = residual_evals.size();
Expand Down Expand Up @@ -305,7 +305,6 @@ std::vector<FieldState> block_solve(const std::vector<WeakForm*>& residual_evals
}
}
allFields.push_back(shape_disp);

struct ZeroDualVectors {
std::vector<FEDualPtr> operator()(const std::vector<FEFieldPtr>& fs)
{
Expand All @@ -319,15 +318,14 @@ std::vector<FieldState> block_solve(const std::vector<WeakForm*>& residual_evals

FieldVecState sol =
shape_disp.create_state<std::vector<FEFieldPtr>, std::vector<FEDualPtr>>(allFields, ZeroDualVectors());

sol.set_eval([=](const gretl::UpstreamStates& upstreams, gretl::DownstreamState& downstream) {
SMITH_MARK_BEGIN("solve forward");
const size_t num_rows = num_state_inputs.size();
std::vector<std::vector<FEFieldPtr>> input_fields(num_rows);
SLIC_ERROR_IF(num_rows != num_param_inputs.size(), "row count for params and columns are inconsistent");
SLIC_ERROR_IF(num_rows != num_param_inputs.size(), "row count for params and states are inconsistent");

// The order of inputs in upstreams seems to be:
// states of residual 0 -> states of residual 1 -> params of residual 0 -> params of residual 1
// The order of inputs in upstreams is:
// states of residual 0, states of residual 1, ... , params of residual 0, params of residual 1, ...
size_t field_count = 0;
for (size_t row_i = 0; row_i < num_rows; ++row_i) {
for (size_t state_i = 0; state_i < num_state_inputs[row_i]; ++state_i) {
Expand All @@ -342,12 +340,23 @@ std::vector<FieldState> block_solve(const std::vector<WeakForm*>& residual_evals

std::vector<FEFieldPtr> diagonal_fields(num_rows);
for (size_t row_i = 0; row_i < num_rows; ++row_i) {
diagonal_fields[row_i] = std::make_shared<FiniteElementState>(*input_fields[row_i][block_indices[row_i][row_i]]);
size_t prime_unknown_row_i = block_indices[row_i][row_i];
SLIC_ERROR_IF(prime_unknown_row_i == invalid_block_index,
"The primary unknown field (field index for block_index[n][n], must not be invalid)");
diagonal_fields[row_i] = std::make_shared<FiniteElementState>(*input_fields[row_i][prime_unknown_row_i]);
}

for (size_t row_i = 0; row_i < num_rows; ++row_i) {
FEFieldPtr primal_field_row_i = diagonal_fields[row_i];
applyBoundaryConditions(time_info.time(), bc_managers[row_i], primal_field_row_i, nullptr);
}

for (size_t row_i = 0; row_i < num_rows; ++row_i) {
for (size_t col_j = 0; col_j < num_rows; ++col_j) {
input_fields[row_i][block_indices[row_i][col_j]] = diagonal_fields[col_j];
size_t prime_unknown_ij = block_indices[row_i][col_j];
if (prime_unknown_ij != invalid_block_index) {
input_fields[row_i][block_indices[row_i][col_j]] = diagonal_fields[col_j];
}
}
}

Expand All @@ -363,13 +372,11 @@ std::vector<FieldState> block_solve(const std::vector<WeakForm*>& residual_evals
*primal_field_row_i = *unknowns[row_i];
applyBoundaryConditions(time_info.time(), bc_managers[row_i], primal_field_row_i, nullptr);
}

for (size_t row_i = 0; row_i < num_rows; ++row_i) {
residuals[row_i] = residual_evals[row_i]->residual(time_info, shape_disp_ptr.get(),
getConstFieldPointers(input_fields[row_i]));
residuals[row_i].SetSubVector(bc_managers[row_i]->allEssentialTrueDofs(), 0.0);
}

return residuals;
};

Expand All @@ -389,29 +396,37 @@ std::vector<FieldState> block_solve(const std::vector<WeakForm*>& residual_evals
std::vector<double> tangent_weights(row_field_inputs.size(), 0.0);
for (size_t col_j = 0; col_j < num_rows; ++col_j) {
size_t field_index_to_diff = block_indices[row_i][col_j];
tangent_weights[field_index_to_diff] = 1.0;
auto jac_ij = residual_evals[row_i]->jacobian(time_info, shape_disp_ptr.get(),
getConstFieldPointers(row_field_inputs), tangent_weights);
jacobians[row_i].emplace_back(std::move(jac_ij));
tangent_weights[field_index_to_diff] = 0.0;
if (field_index_to_diff != invalid_block_index) {
tangent_weights[field_index_to_diff] = 1.0;
auto jac_ij = residual_evals[row_i]->jacobian(time_info, shape_disp_ptr.get(),
getConstFieldPointers(row_field_inputs), tangent_weights);
jacobians[row_i].emplace_back(std::move(jac_ij));
tangent_weights[field_index_to_diff] = 0.0;
} else {
jacobians[row_i].emplace_back(nullptr);
}
}
}

// Apply BCs to the block system
for (size_t row_i = 0; row_i < num_rows; ++row_i) {
mfem::HypreParMatrix* Jii =
jacobians[row_i][row_i]->EliminateRowsCols(bc_managers[row_i]->allEssentialTrueDofs());
delete Jii;
if (jacobians[row_i][row_i]) {
jacobians[row_i][row_i]->EliminateBC(bc_managers[row_i]->allEssentialTrueDofs(),
mfem::Operator::DiagonalPolicy::DIAG_ONE);
}
for (size_t col_j = 0; col_j < num_rows; ++col_j) {
if (col_j != row_i) {
jacobians[row_i][col_j]->EliminateRows(bc_managers[row_i]->allEssentialTrueDofs());
mfem::HypreParMatrix* Jji =
jacobians[col_j][row_i]->EliminateCols(bc_managers[row_i]->allEssentialTrueDofs());
delete Jji;
if (jacobians[row_i][col_j]) {
jacobians[row_i][col_j]->EliminateRows(bc_managers[row_i]->allEssentialTrueDofs());
}
if (jacobians[col_j][row_i]) {
mfem::HypreParMatrix* Jji =
jacobians[col_j][row_i]->EliminateCols(bc_managers[row_i]->allEssentialTrueDofs());
delete Jji;
}
}
}
}

return jacobians;
};

Expand Down Expand Up @@ -558,7 +573,11 @@ std::vector<FieldState> block_solve(const std::vector<WeakForm*>& residual_evals
for (size_t i = 0; i < num_rows_; ++i) {
FieldState s = gretl::create_state<FEFieldPtr, FEDualPtr>(
zero_dual_from_state(),
[i](const std::vector<FEFieldPtr>& sols) { return std::make_shared<FiniteElementState>(*sols[i]); },
[i](const std::vector<FEFieldPtr>& sols) {
auto state_copy = std::make_shared<FiniteElementState>(sols[i]->space(), sols[i]->name());
*state_copy = *sols[i];
return state_copy;
},
[i](const std::vector<FEFieldPtr>&, const FEFieldPtr&, std::vector<FEDualPtr>& sols_,
const FEDualPtr& output_) { *sols_[i] += *output_; },
sol);
Expand Down
Loading