|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include "rxmesh/diff/hessian_sparse_matrix.h" |
| 4 | + |
| 5 | +#include "rxmesh/diff/scalar.h" |
| 6 | + |
| 7 | +namespace rxmesh { |
| 8 | + |
| 9 | +/** |
| 10 | + * @brief Construct the 'dynamic' sparse Hessian of type T with K variables per |
| 11 | + * vertex. Dynamic here means the sparsity of the Hessian is going to change |
| 12 | + during runtime. The matrix size is (V * K) X (V * K) where V is the number of |
| 13 | + * vertices in the mesh. |
| 14 | + * |
| 15 | + */ |
| 16 | +template <typename T, int K> |
| 17 | +struct DynamicHessianSparseMatrix : public HessianSparseMatrix<T, k> |
| 18 | +{ |
| 19 | + using Type = T; |
| 20 | + |
| 21 | + static constexpr int K_ = K; |
| 22 | + |
| 23 | + using ScalarT = Scalar<T, K, true>; |
| 24 | + |
| 25 | + using IndexT = typename SparseMatrix<T>::IndexT; |
| 26 | + |
| 27 | + |
| 28 | + DynamicHessianSparseMatrix() : HessianSparseMatrix<T, k>() |
| 29 | + { |
| 30 | + } |
| 31 | + |
| 32 | + DynamicHessianSparseMatrix(const RXMeshStatic& rx, Op op = Op::VV) |
| 33 | + : HessianSparseMatrix<T>(rx, op) |
| 34 | + { |
| 35 | + } |
| 36 | + |
| 37 | + |
| 38 | + /** |
| 39 | + * @brief insert more entries to the hessian matrix. The input here is the |
| 40 | + * list of vertices that will be interacting. Thus, we need to extend them |
| 41 | + * by the number of their replicate, i.e., k |
| 42 | + */ |
| 43 | + __host__ void insert(uint32_t size, IndexT* d_rows, IndexT* d_cols) |
| 44 | + { |
| 45 | + // Here, we assume the number of rows and cols is the same and only |
| 46 | + // the sparsity is changing |
| 47 | + |
| 48 | + constexpr uint32_t blockThreads = 256; |
| 49 | + |
| 50 | + // fill the new row_ptr with the data from the mesh connectivity |
| 51 | + rx.run_kernel<blockThreads>( |
| 52 | + {Op::VV}, |
| 53 | + detail::sparse_mat_prescan<Op::VV, blockThreads>, |
| 54 | + m_d_row_ptr, |
| 55 | + k); |
| 56 | + |
| 57 | + // fill the new row_ptr with the data from the new entries |
| 58 | + detail:: |
| 59 | + sparse_mat_prescan<<<DIVIDE_UP(size, blockThreads), blockThreads>>>( |
| 60 | + m_d_row_ptr, size, d_rows, d_cols, k); |
| 61 | + } |
| 62 | + |
| 63 | + // delete the functions that access the matrix using only the VertexHandle |
| 64 | + // since with the Hessian, we should also have the local index (the index |
| 65 | + // within the kxk matrix) |
| 66 | + __device__ __host__ const T& operator()(const VertexHandle& row_v, |
| 67 | + const VertexHandle& col_v) const = |
| 68 | + delete; |
| 69 | + |
| 70 | + __device__ __host__ T& operator()(const VertexHandle& row_v, |
| 71 | + const VertexHandle& col_v) = delete; |
| 72 | +}; |
| 73 | + |
| 74 | +} // namespace rxmesh |
0 commit comments