@@ -10,7 +10,8 @@ namespace rxmesh {
1010template <typename SpMatT, int DenseMatOrder = Eigen::ColMajor>
1111struct cuDSSCholeskySolver : public DirectSolver <SpMatT, DenseMatOrder>
1212{
13- using T = typename SpMatT::Type;
13+ using IndexT = typename SpMatT::IndexT;
14+ using T = typename SpMatT::Type;
1415
1516 cuDSSCholeskySolver ()
1617 : DirectSolver<SpMatT, DenseMatOrder>(), m_first_pre_solve(false )
@@ -46,11 +47,12 @@ struct cuDSSCholeskySolver : public DirectSolver<SpMatT, DenseMatOrder>
4647 * @brief pre_solve should be called before calling the solve() method.
4748 * and it should be called every time the matrix is updated
4849 */
49- virtual void pre_solve (DenseMatrix<T, DenseMatOrder>& B_mat,
50+ virtual void pre_solve (RXMeshStatic& rx,
51+ DenseMatrix<T, DenseMatOrder>& B_mat,
5052 DenseMatrix<T, DenseMatOrder>& X_mat)
5153 {
5254 if (m_first_pre_solve) {
53- permute (B_mat, X_mat);
55+ permute (rx, B_mat, X_mat);
5456 analyze_pattern ();
5557 factorize ();
5658 m_first_pre_solve = false ;
@@ -62,7 +64,8 @@ struct cuDSSCholeskySolver : public DirectSolver<SpMatT, DenseMatOrder>
6264 /* *
6365 * @brief permute the matrix to reduce the non-zero fill-in
6466 */
65- virtual void permute (DenseMatrix<T, DenseMatOrder>& B_mat,
67+ virtual void permute (RXMeshStatic& rx,
68+ DenseMatrix<T, DenseMatOrder>& B_mat,
6669 DenseMatrix<T, DenseMatOrder>& X_mat)
6770 {
6871 if (m_first_pre_solve) {
@@ -88,8 +91,18 @@ struct cuDSSCholeskySolver : public DirectSolver<SpMatT, DenseMatOrder>
8891 reorder_alg = CUDSS_ALG_1;
8992 break ;
9093 }
94+ case PermuteMethod::GPUND: {
95+ reorder_alg = CUDSS_ALG_DEFAULT;
96+ DirectSolver<SpMatT, DenseMatOrder>::permute (rx);
97+ CUDSS_ERROR (
98+ cudssDataSet (m_cudss_handle,
99+ m_cudss_data,
100+ CUDSS_DATA_USER_PERM,
101+ this ->m_d_permute ,
102+ size_t (m_mat->rows () * sizeof (IndexT))));
103+ break ;
104+ }
91105 case PermuteMethod::GPUMGND:
92- case PermuteMethod::GPUND:
93106 default :
94107 reorder_alg = CUDSS_ALG_DEFAULT;
95108 }
0 commit comments