4141
4242#include < deal.II/matrix_free/portable_fe_evaluation.h>
4343#include < deal.II/matrix_free/portable_matrix_free.h>
44+ #include < deal.II/matrix_free/tools.h>
4445
4546// boost
4647#include < boost/algorithm/string.hpp>
@@ -308,9 +309,32 @@ class OperatorDealii : public OperatorBase<Number, MemorySpace::Default>
308309 * Compute inverse of diagonal.
309310 */
310311 void
311- compute_inverse_diagonal (VectorType &) const override
312+ compute_inverse_diagonal (VectorType &diagonal ) const override
312313 {
313- AssertThrow (false , ExcNotImplemented ());
314+ this ->initialize_dof_vector (diagonal);
315+
316+ const unsigned int n_components = dof_handler.get_fe ().n_components ();
317+ const unsigned int fe_degree = dof_handler.get_fe ().tensor_degree ();
318+ const unsigned int n_q_points_1d = quadrature.get_tensor_basis ()[0 ].size ();
319+
320+ if (n_components == 1 && fe_degree == 1 && n_q_points_1d == 2 )
321+ this ->compute_inverse_diagonal_internal <1 , 1 , 2 >(diagonal);
322+ else if (n_components == 1 && fe_degree == 2 && n_q_points_1d == 3 )
323+ this ->compute_inverse_diagonal_internal <1 , 2 , 3 >(diagonal);
324+ else if (n_components == dim && fe_degree == 1 && n_q_points_1d == 2 )
325+ this ->compute_inverse_diagonal_internal <dim, 1 , 2 >(diagonal);
326+ else if (n_components == dim && fe_degree == 2 && n_q_points_1d == 3 )
327+ this ->compute_inverse_diagonal_internal <dim, 2 , 3 >(diagonal);
328+ else if (n_components == 1 && fe_degree == 1 && n_q_points_1d == 3 )
329+ this ->compute_inverse_diagonal_internal <1 , 1 , 3 >(diagonal);
330+ else if (n_components == 1 && fe_degree == 2 && n_q_points_1d == 4 )
331+ this ->compute_inverse_diagonal_internal <1 , 2 , 4 >(diagonal);
332+ else if (n_components == dim && fe_degree == 1 && n_q_points_1d == 3 )
333+ this ->compute_inverse_diagonal_internal <dim, 1 , 3 >(diagonal);
334+ else if (n_components == dim && fe_degree == 2 && n_q_points_1d == 4 )
335+ this ->compute_inverse_diagonal_internal <dim, 2 , 4 >(diagonal);
336+ else
337+ AssertThrow (false , ExcInternalError ());
314338 }
315339
316340private:
@@ -334,6 +358,38 @@ class OperatorDealii : public OperatorBase<Number, MemorySpace::Default>
334358 }
335359 }
336360
361+ /* *
362+ * Templated compute_inverse_diagonal function.
363+ */
364+ template <int n_components, int fe_degree, int n_q_points_1d>
365+ void
366+ compute_inverse_diagonal_internal (VectorType &diagonal) const
367+ {
368+ if (bp <= BPType::BP2) // mass matrix
369+ {
370+ OperatorDealiiMassQuad<dim, fe_degree, n_q_points_1d, n_components, Number> op_quad;
371+
372+ MatrixFreeTools::compute_diagonal<dim, fe_degree, n_q_points_1d, n_components, Number>(
373+ matrix_free, diagonal, op_quad, EvaluationFlags::values, EvaluationFlags::values);
374+ }
375+ else
376+ {
377+ OperatorDealiiLaplaceQuad<dim, fe_degree, n_q_points_1d, n_components, Number> op_quad;
378+
379+ MatrixFreeTools::compute_diagonal<dim, fe_degree, n_q_points_1d, n_components, Number>(
380+ matrix_free, diagonal, op_quad, EvaluationFlags::gradients, EvaluationFlags::gradients);
381+ }
382+
383+
384+ Number *diagonal_ptr = diagonal.get_values ();
385+
386+ Kokkos::parallel_for (
387+ " lethe::invert_vector" ,
388+ Kokkos::RangePolicy<MemorySpace::Default::kokkos_space::execution_space>(
389+ 0 , diagonal.locally_owned_size ()),
390+ KOKKOS_LAMBDA (int i) { diagonal_ptr[i] = 1.0 / diagonal_ptr[i]; });
391+ }
392+
337393 /* *
338394 * Mapping object passed to the constructor.
339395 */
@@ -435,6 +491,10 @@ main(int argc, char *argv[])
435491 // create solver
436492 ReductionControl reduction_control (100 , 1e-20 , 1e-6 );
437493
494+ // create preconditioner
495+ DiagonalMatrix<VectorType> diagonal_matrix;
496+ op.compute_inverse_diagonal (diagonal_matrix.get_vector ());
497+
438498 std::chrono::time_point<std::chrono::system_clock> now;
439499
440500 bool not_converged = false ;
@@ -444,7 +504,7 @@ main(int argc, char *argv[])
444504 // solve problem
445505 SolverCG<VectorType> solver (reduction_control);
446506 now = std::chrono::system_clock::now ();
447- solver.solve (op, v, u, PreconditionIdentity () );
507+ solver.solve (op, v, u, diagonal_matrix );
448508 }
449509 catch (const SolverControl::NoConvergence &)
450510 {
0 commit comments