@@ -68,7 +68,7 @@ __global__ static void calc_param_loss(
6868 rxmesh::DenseMatrix<T, Eigen::RowMajor> grad,
6969 rxmesh::HessianSparseMatrix<T, VariableDim> hess,
7070 rxmesh::FaceAttribute<T> f_obj_func,
71- bool project_hessian = true )
71+ bool project_hessian)
7272{
7373 using namespace rxmesh ;
7474
@@ -253,9 +253,14 @@ void line_search(
253253 }
254254 });
255255
256- calc_param_loss<T, VariableDim, blockThreads, false >
257- <<<lb.blocks, lb.num_threads, lb.smem_bytes_dyn>>> (
258- rx.get_context (), rest_shape, sol_temp, grad, hess, f_obj_func);
256+ rx.run_kernel (lb,
257+ calc_param_loss<T, VariableDim, blockThreads, false >,
258+ rest_shape,
259+ sol_temp,
260+ grad,
261+ hess,
262+ f_obj_func,
263+ true );
259264
260265 T f_new = reducer.reduce (f_obj_func, cub::Sum (), 0 );
261266
@@ -365,17 +370,13 @@ TEST(DiffAttribute, NewtonMethod)
365370
366371 int num_iterations = 100 ;
367372
368- LaunchBox<blockThreads> lb;
369- rx.prepare_launch_box (
370- {Op::FV}, lb, (void *)calc_rest_shape<T, blockThreads>);
371-
372- calc_rest_shape<T, blockThreads>
373- <<<lb.blocks, lb.num_threads, lb.smem_bytes_dyn>>> (
374- rx.get_context (), coordinates, rest_shape);
373+ rx.run_kernel <blockThreads>(
374+ {Op::FV}, calc_rest_shape<T, blockThreads>, coordinates, rest_shape);
375375
376376 CUDA_ERROR (cudaDeviceSynchronize ());
377377
378378
379+ LaunchBox<blockThreads> lb;
379380 rx.prepare_launch_box (
380381 {Op::FV},
381382 lb,
@@ -398,9 +399,15 @@ TEST(DiffAttribute, NewtonMethod)
398399 for (iter = 0 ; iter < num_iterations; ++iter) {
399400
400401 // 1) calculate objective function
401- calc_param_loss<T, VariableDim, blockThreads, true >
402- <<<lb.blocks, lb.num_threads, lb.smem_bytes_dyn>>> (
403- rx.get_context (), rest_shape, uv, grad, hess, f_obj_func);
402+ rx.run_kernel (lb,
403+ calc_param_loss<T, VariableDim, blockThreads, true >,
404+ rest_shape,
405+ uv,
406+ grad,
407+ hess,
408+ f_obj_func,
409+ true );
410+
404411
405412 T f = reducer.reduce (f_obj_func, cub::Sum (), 0 );
406413
0 commit comments