Skip to content

Commit dc60d0e

Browse files
committed
dyn Hessian
1 parent 1c2c5a2 commit dc60d0e

File tree

3 files changed

+122
-3
lines changed

3 files changed

+122
-3
lines changed

include/rxmesh/context.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,25 @@ class Context
7070
return m_num_vertices;
7171
}
7272

73+
/**
74+
* @brief Total number of vertices in mesh
75+
*/
76+
template <typename HandleT>
77+
__device__ __forceinline__ uint32_t get_num() const
78+
{
79+
if constexpr (std::is_same_v<HandleT, VertexHandle>) {
80+
return m_num_vertices[0];
81+
}
82+
83+
if constexpr (std::is_same_v<HandleT, EdgeHandle>) {
84+
return m_num_edges[0];
85+
}
86+
87+
if constexpr (std::is_same_v<HandleT, FaceHandle>) {
88+
return m_num_faces[0];
89+
}
90+
}
91+
7392
/**
7493
* @brief Total number of patches in mesh
7594
*/
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#pragma once
2+
3+
#include "rxmesh/diff/hessian_sparse_matrix.h"
4+
5+
#include "rxmesh/diff/scalar.h"
6+
7+
namespace rxmesh {
8+
9+
/**
10+
* @brief Construct the 'dynamic' sparse Hessian of type T with K variables per
11+
* vertex. Dynamic here means the sparsity of the Hessian is going to change
12+
during runtime. The matrix size is (V * K) X (V * K) where V is the number of
13+
* vertices in the mesh.
14+
*
15+
*/
16+
template <typename T, int K>
17+
struct DynamicHessianSparseMatrix : public HessianSparseMatrix<T, k>
18+
{
19+
using Type = T;
20+
21+
static constexpr int K_ = K;
22+
23+
using ScalarT = Scalar<T, K, true>;
24+
25+
using IndexT = typename SparseMatrix<T>::IndexT;
26+
27+
28+
DynamicHessianSparseMatrix() : HessianSparseMatrix<T, k>()
29+
{
30+
}
31+
32+
DynamicHessianSparseMatrix(const RXMeshStatic& rx, Op op = Op::VV)
33+
: HessianSparseMatrix<T>(rx, op)
34+
{
35+
}
36+
37+
38+
/**
39+
* @brief insert more entries to the hessian matrix. The input here is the
40+
* list of vertices that will be interacting. Thus, we need to extend them
41+
* by the number of their replicate, i.e., k
42+
*/
43+
__host__ void insert(uint32_t size, IndexT* d_rows, IndexT* d_cols)
44+
{
45+
// Here, we assume the number of rows and cols is the same and only
46+
// the sparsity is changing
47+
48+
constexpr uint32_t blockThreads = 256;
49+
50+
// fill the new row_ptr with the data from the mesh connectivity
51+
rx.run_kernel<blockThreads>(
52+
{Op::VV},
53+
detail::sparse_mat_prescan<Op::VV, blockThreads>,
54+
m_d_row_ptr,
55+
k);
56+
57+
// fill the new row_ptr with the data from the new entries
58+
detail::
59+
sparse_mat_prescan<<<DIVIDE_UP(size, blockThreads), blockThreads>>>(
60+
m_d_row_ptr, size, d_rows, d_cols, k);
61+
}
62+
63+
// delete the functions that access the matrix using only the VertexHandle
64+
// since with the Hessian, we should also have the local index (the index
65+
// within the kxk matrix)
66+
__device__ __host__ const T& operator()(const VertexHandle& row_v,
67+
const VertexHandle& col_v) const =
68+
delete;
69+
70+
__device__ __host__ T& operator()(const VertexHandle& row_v,
71+
const VertexHandle& col_v) = delete;
72+
};
73+
74+
} // namespace rxmesh

include/rxmesh/matrix/sparse_matrix_kernels.cuh

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ __global__ static void sparse_mat_prescan(const rxmesh::Context context,
1515
IndexT* row_ptr,
1616
IndexT replicate)
1717
{
18+
bool is_aos = true;
19+
1820
using namespace rxmesh;
1921

2022
using HandleT = typename InputHandle<op>::type;
@@ -27,10 +29,17 @@ __global__ static void sparse_mat_prescan(const rxmesh::Context context,
2729
IndexT size = iter.size() + 1;
2830
size *= replicate;
2931
IndexT offset = context.prefix<HandleT>()[patch_id] + local_id;
30-
offset *= replicate;
3132

32-
for (IndexT i = 0; i < replicate; ++i) {
33-
row_ptr[offset + i] = size;
33+
if (is_aos) {
34+
offset *= replicate;
35+
for (IndexT i = 0; i < replicate; ++i) {
36+
row_ptr[offset + i] = size;
37+
}
38+
} else {
39+
const uint32_t num_elements = context.get_num<HandleT>();
40+
for (IndexT i = 0; i < replicate; ++i) {
41+
row_ptr[num_elements * i + offset] = size;
42+
}
3443
}
3544
};
3645

@@ -40,6 +49,23 @@ __global__ static void sparse_mat_prescan(const rxmesh::Context context,
4049
query.dispatch<op>(block, shrd_alloc, init_lambda);
4150
}
4251

52+
template <typename IndexT = int>
53+
__global__ static void sparse_mat_prescan(IndexT* row_ptr,
54+
const IndexT size,
55+
const IndexT* rows,
56+
const IndexT* cols,
57+
const IndexT replicate)
58+
{
59+
const uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x;
60+
if (tid < size) {
61+
const int row = rows[tid] * replicate;
62+
63+
for (IndexT i = 0; i < replicate; ++i) {
64+
::atomicAdd(row_ptr + (row + i), replicate);
65+
}
66+
}
67+
}
68+
4369
template <Op op, uint32_t blockThreads, typename IndexT = int>
4470
__global__ static void sparse_mat_col_fill(const rxmesh::Context context,
4571
IndexT* row_ptr,

0 commit comments

Comments
 (0)