Skip to content

Commit f881f4d

Browse files
committed
use the simplified kernel launch
1 parent baa5199 commit f881f4d

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

tests/RXMesh_test/test_hessian.cu

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ __global__ static void calc_param_loss(
6868
rxmesh::DenseMatrix<T, Eigen::RowMajor> grad,
6969
rxmesh::HessianSparseMatrix<T, VariableDim> hess,
7070
rxmesh::FaceAttribute<T> f_obj_func,
71-
bool project_hessian = true)
71+
bool project_hessian)
7272
{
7373
using namespace rxmesh;
7474

@@ -253,9 +253,14 @@ void line_search(
253253
}
254254
});
255255

256-
calc_param_loss<T, VariableDim, blockThreads, false>
257-
<<<lb.blocks, lb.num_threads, lb.smem_bytes_dyn>>>(
258-
rx.get_context(), rest_shape, sol_temp, grad, hess, f_obj_func);
256+
rx.run_kernel(lb,
257+
calc_param_loss<T, VariableDim, blockThreads, false>,
258+
rest_shape,
259+
sol_temp,
260+
grad,
261+
hess,
262+
f_obj_func,
263+
true);
259264

260265
T f_new = reducer.reduce(f_obj_func, cub::Sum(), 0);
261266

@@ -365,17 +370,13 @@ TEST(DiffAttribute, NewtonMethod)
365370

366371
int num_iterations = 100;
367372

368-
LaunchBox<blockThreads> lb;
369-
rx.prepare_launch_box(
370-
{Op::FV}, lb, (void*)calc_rest_shape<T, blockThreads>);
371-
372-
calc_rest_shape<T, blockThreads>
373-
<<<lb.blocks, lb.num_threads, lb.smem_bytes_dyn>>>(
374-
rx.get_context(), coordinates, rest_shape);
373+
rx.run_kernel<blockThreads>(
374+
{Op::FV}, calc_rest_shape<T, blockThreads>, coordinates, rest_shape);
375375

376376
CUDA_ERROR(cudaDeviceSynchronize());
377377

378378

379+
LaunchBox<blockThreads> lb;
379380
rx.prepare_launch_box(
380381
{Op::FV},
381382
lb,
@@ -398,9 +399,15 @@ TEST(DiffAttribute, NewtonMethod)
398399
for (iter = 0; iter < num_iterations; ++iter) {
399400

400401
// 1) calcu objective function
401-
calc_param_loss<T, VariableDim, blockThreads, true>
402-
<<<lb.blocks, lb.num_threads, lb.smem_bytes_dyn>>>(
403-
rx.get_context(), rest_shape, uv, grad, hess, f_obj_func);
402+
rx.run_kernel(lb,
403+
calc_param_loss<T, VariableDim, blockThreads, true>,
404+
rest_shape,
405+
uv,
406+
grad,
407+
hess,
408+
f_obj_func,
409+
true);
410+
404411

405412
T f = reducer.reduce(f_obj_func, cub::Sum(), 0);
406413

0 commit comments

Comments
 (0)