@@ -32,7 +32,12 @@ struct HessianSparseMatrix : public SparseMatrix<T>
3232 HessianSparseMatrix (const RXMeshStatic& rx,
3333 const int extra_nnz_entries,
3434 Op op = Op::VV)
35- : SparseMatrix<T>(rx, 1 .0f , extra_nnz_entries, op, K)
35+ : SparseMatrix<T>(rx,
36+ 1 .0f ,
37+ extra_nnz_entries,
38+ op,
39+ detail::BlockDim (K, K),
40+ true )
3641 {
3742 }
3843
@@ -49,9 +54,9 @@ struct HessianSparseMatrix : public SparseMatrix<T>
4954 assert (is_non_zero (row_v, col_v));
5055
5156 const IndexT r_id =
52- this ->get_row_id (row_v) * this ->m_replicate + local_i;
57+ this ->get_row_id (row_v) * this ->m_block_dim . x + local_i;
5358 const IndexT c_id =
54- this ->get_row_id (col_v) * this ->m_replicate + local_j;
59+ this ->get_row_id (col_v) * this ->m_block_dim . y + local_j;
5560
5661 return SparseMatrix<T>::operator ()(r_id, c_id);
5762 }
@@ -68,9 +73,9 @@ struct HessianSparseMatrix : public SparseMatrix<T>
6873 assert (is_non_zero (row_v, col_v));
6974
7075 const IndexT r_id =
71- this ->get_row_id (row_v) * this ->m_replicate + local_i;
76+ this ->get_row_id (row_v) * this ->m_block_dim . x + local_i;
7277 const IndexT c_id =
73- this ->get_row_id (col_v) * this ->m_replicate + local_j;
78+ this ->get_row_id (col_v) * this ->m_block_dim . y + local_j;
7479
7580 return SparseMatrix<T>::operator ()(r_id, c_id);
7681 }
@@ -88,9 +93,9 @@ struct HessianSparseMatrix : public SparseMatrix<T>
8893 assert (is_non_zero (row_v, col_v));
8994
9095 const IndexT r_id =
91- this ->get_row_id (row_v) * this ->m_replicate + local_i;
96+ this ->get_row_id (row_v) * this ->m_block_dim . x + local_i;
9297 const IndexT c_id =
93- this ->get_row_id (col_v) * this ->m_replicate + local_j;
98+ this ->get_row_id (col_v) * this ->m_block_dim . y + local_j;
9499
95100 return {r_id, c_id};
96101 }
@@ -102,8 +107,8 @@ struct HessianSparseMatrix : public SparseMatrix<T>
102107 __device__ __host__ const bool is_non_zero (const VertexHandle& row_v,
103108 const VertexHandle& col_v) const
104109 {
105- const IndexT r_id = this ->get_row_id (row_v) * this ->m_replicate ;
106- const IndexT c_id = this ->get_row_id (col_v) * this ->m_replicate ;
110+ const IndexT r_id = this ->get_row_id (row_v) * this ->m_block_dim . x ;
111+ const IndexT c_id = this ->get_row_id (col_v) * this ->m_block_dim . y ;
107112
108113 return SparseMatrix<T>::is_non_zero (r_id, c_id);
109114 }
0 commit comments