Skip to content

Commit 09fa620

Browse files
committed
Compute diagonal
1 parent 788b606 commit 09fa620

File tree

1 file changed

+63
-3
lines changed

1 file changed

+63
-3
lines changed

examples/deal.II/bps_02.cc

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141

4242
#include <deal.II/matrix_free/portable_fe_evaluation.h>
4343
#include <deal.II/matrix_free/portable_matrix_free.h>
44+
#include <deal.II/matrix_free/tools.h>
4445

4546
// boost
4647
#include <boost/algorithm/string.hpp>
@@ -308,9 +309,32 @@ class OperatorDealii : public OperatorBase<Number, MemorySpace::Default>
308309
* Compute inverse of diagonal.
309310
*/
310311
void
311-
compute_inverse_diagonal(VectorType &) const override
312+
compute_inverse_diagonal(VectorType &diagonal) const override
312313
{
313-
AssertThrow(false, ExcNotImplemented());
314+
this->initialize_dof_vector(diagonal);
315+
316+
const unsigned int n_components = dof_handler.get_fe().n_components();
317+
const unsigned int fe_degree = dof_handler.get_fe().tensor_degree();
318+
const unsigned int n_q_points_1d = quadrature.get_tensor_basis()[0].size();
319+
320+
if (n_components == 1 && fe_degree == 1 && n_q_points_1d == 2)
321+
this->compute_inverse_diagonal_internal<1, 1, 2>(diagonal);
322+
else if (n_components == 1 && fe_degree == 2 && n_q_points_1d == 3)
323+
this->compute_inverse_diagonal_internal<1, 2, 3>(diagonal);
324+
else if (n_components == dim && fe_degree == 1 && n_q_points_1d == 2)
325+
this->compute_inverse_diagonal_internal<dim, 1, 2>(diagonal);
326+
else if (n_components == dim && fe_degree == 2 && n_q_points_1d == 3)
327+
this->compute_inverse_diagonal_internal<dim, 2, 3>(diagonal);
328+
else if (n_components == 1 && fe_degree == 1 && n_q_points_1d == 3)
329+
this->compute_inverse_diagonal_internal<1, 1, 3>(diagonal);
330+
else if (n_components == 1 && fe_degree == 2 && n_q_points_1d == 4)
331+
this->compute_inverse_diagonal_internal<1, 2, 4>(diagonal);
332+
else if (n_components == dim && fe_degree == 1 && n_q_points_1d == 3)
333+
this->compute_inverse_diagonal_internal<dim, 1, 3>(diagonal);
334+
else if (n_components == dim && fe_degree == 2 && n_q_points_1d == 4)
335+
this->compute_inverse_diagonal_internal<dim, 2, 4>(diagonal);
336+
else
337+
AssertThrow(false, ExcInternalError());
314338
}
315339

316340
private:
@@ -334,6 +358,38 @@ class OperatorDealii : public OperatorBase<Number, MemorySpace::Default>
334358
}
335359
}
336360

361+
/**
362+
* Templated compute_inverse_diagonal function.
363+
*/
364+
template <int n_components, int fe_degree, int n_q_points_1d>
365+
void
366+
compute_inverse_diagonal_internal(VectorType &diagonal) const
367+
{
368+
if (bp <= BPType::BP2) // mass matrix
369+
{
370+
OperatorDealiiMassQuad<dim, fe_degree, n_q_points_1d, n_components, Number> op_quad;
371+
372+
MatrixFreeTools::compute_diagonal<dim, fe_degree, n_q_points_1d, n_components, Number>(
373+
matrix_free, diagonal, op_quad, EvaluationFlags::values, EvaluationFlags::values);
374+
}
375+
else
376+
{
377+
OperatorDealiiLaplaceQuad<dim, fe_degree, n_q_points_1d, n_components, Number> op_quad;
378+
379+
MatrixFreeTools::compute_diagonal<dim, fe_degree, n_q_points_1d, n_components, Number>(
380+
matrix_free, diagonal, op_quad, EvaluationFlags::gradients, EvaluationFlags::gradients);
381+
}
382+
383+
384+
Number *diagonal_ptr = diagonal.get_values();
385+
386+
Kokkos::parallel_for(
387+
"lethe::invert_vector",
388+
Kokkos::RangePolicy<MemorySpace::Default::kokkos_space::execution_space>(
389+
0, diagonal.locally_owned_size()),
390+
KOKKOS_LAMBDA(int i) { diagonal_ptr[i] = 1.0 / diagonal_ptr[i]; });
391+
}
392+
337393
/**
338394
* Mapping object passed to the constructor.
339395
*/
@@ -435,6 +491,10 @@ main(int argc, char *argv[])
435491
// create solver
436492
ReductionControl reduction_control(100, 1e-20, 1e-6);
437493

494+
// create preconditioner
495+
DiagonalMatrix<VectorType> diagonal_matrix;
496+
op.compute_inverse_diagonal(diagonal_matrix.get_vector());
497+
438498
std::chrono::time_point<std::chrono::system_clock> now;
439499

440500
bool not_converged = false;
@@ -444,7 +504,7 @@ main(int argc, char *argv[])
444504
// solve problem
445505
SolverCG<VectorType> solver(reduction_control);
446506
now = std::chrono::system_clock::now();
447-
solver.solve(op, v, u, PreconditionIdentity());
507+
solver.solve(op, v, u, diagonal_matrix);
448508
}
449509
catch (const SolverControl::NoConvergence &)
450510
{

0 commit comments

Comments
 (0)