Skip to content

Commit 7a717ad

Browse files
committed
generalize the blocks in sparse matrix to non-square blocks
1 parent dddfa93 commit 7a717ad

File tree

5 files changed

+122
-252
lines changed

5 files changed

+122
-252
lines changed

include/rxmesh/diff/hessian_sparse_matrix.h

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@ struct HessianSparseMatrix : public SparseMatrix<T>
3232
HessianSparseMatrix(const RXMeshStatic& rx,
3333
const int extra_nnz_entries,
3434
Op op = Op::VV)
35-
: SparseMatrix<T>(rx, 1.0f, extra_nnz_entries, op, K)
35+
: SparseMatrix<T>(rx,
36+
1.0f,
37+
extra_nnz_entries,
38+
op,
39+
detail::BlockDim(K, K),
40+
true)
3641
{
3742
}
3843

@@ -49,9 +54,9 @@ struct HessianSparseMatrix : public SparseMatrix<T>
4954
assert(is_non_zero(row_v, col_v));
5055

5156
const IndexT r_id =
52-
this->get_row_id(row_v) * this->m_replicate + local_i;
57+
this->get_row_id(row_v) * this->m_block_dim.x + local_i;
5358
const IndexT c_id =
54-
this->get_row_id(col_v) * this->m_replicate + local_j;
59+
this->get_row_id(col_v) * this->m_block_dim.y + local_j;
5560

5661
return SparseMatrix<T>::operator()(r_id, c_id);
5762
}
@@ -68,9 +73,9 @@ struct HessianSparseMatrix : public SparseMatrix<T>
6873
assert(is_non_zero(row_v, col_v));
6974

7075
const IndexT r_id =
71-
this->get_row_id(row_v) * this->m_replicate + local_i;
76+
this->get_row_id(row_v) * this->m_block_dim.x + local_i;
7277
const IndexT c_id =
73-
this->get_row_id(col_v) * this->m_replicate + local_j;
78+
this->get_row_id(col_v) * this->m_block_dim.y + local_j;
7479

7580
return SparseMatrix<T>::operator()(r_id, c_id);
7681
}
@@ -88,9 +93,9 @@ struct HessianSparseMatrix : public SparseMatrix<T>
8893
assert(is_non_zero(row_v, col_v));
8994

9095
const IndexT r_id =
91-
this->get_row_id(row_v) * this->m_replicate + local_i;
96+
this->get_row_id(row_v) * this->m_block_dim.x + local_i;
9297
const IndexT c_id =
93-
this->get_row_id(col_v) * this->m_replicate + local_j;
98+
this->get_row_id(col_v) * this->m_block_dim.y + local_j;
9499

95100
return {r_id, c_id};
96101
}
@@ -102,8 +107,8 @@ struct HessianSparseMatrix : public SparseMatrix<T>
102107
__device__ __host__ const bool is_non_zero(const VertexHandle& row_v,
103108
const VertexHandle& col_v) const
104109
{
105-
const IndexT r_id = this->get_row_id(row_v) * this->m_replicate;
106-
const IndexT c_id = this->get_row_id(col_v) * this->m_replicate;
110+
const IndexT r_id = this->get_row_id(row_v) * this->m_block_dim.x;
111+
const IndexT c_id = this->get_row_id(col_v) * this->m_block_dim.y;
107112

108113
return SparseMatrix<T>::is_non_zero(r_id, c_id);
109114
}

include/rxmesh/matrix/block_dim.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#pragma once
2+
3+
namespace rxmesh {
4+
namespace detail {
5+
// The dimension of a (dense) block in the sparse matrix. In most cases
6+
// the block size is just 1x1 which represent a single non-zero value.
7+
// But in, e.g., Hessians, the blocks could be kxk.
8+
struct BlockDim
9+
{
10+
int x, y;
11+
12+
BlockDim() : x(1), y(1)
13+
{
14+
}
15+
16+
BlockDim(int x_, int y_) : x(x_), y(y_)
17+
{
18+
}
19+
};
20+
21+
} // namespace detail
22+
} // namespace rxmesh

0 commit comments

Comments
 (0)