Skip to content

Commit baa5199

Browse files
committed
helper function to simplify kernel launch
1 parent 52e86ac commit baa5199

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed

include/rxmesh/rxmesh_static.h

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,89 @@ class RXMeshStatic : public RXMesh
639639
}
640640
}
641641

642+
643+
/**
644+
* @brief Launching a kernel knowing its launch box
645+
* @tparam ...ArgsT infered
646+
* @tparam blockThreads the block size
647+
* @param lb launch box populated via prepare_launch_box
648+
* @param kernel the kernel to launch
649+
* @param ...args input parameters to the kernel
650+
*/
651+
template <uint32_t blockThreads, typename KernelT, typename... ArgsT>
652+
void run_kernel(const LaunchBox<blockThreads>& lb,
653+
const KernelT kernel,
654+
ArgsT... args) const
655+
{
656+
kernel<<<lb.blocks, lb.num_threads, lb.smem_bytes_dyn>>>(get_context(),
657+
args...);
658+
}
659+
660+
/**
661+
* @brief run a kernel that will require a query operation
662+
* @tparam ...ArgsT infered
663+
* @tparam blockThreads the block size
664+
* @param op list of query operations used inside the kernel
665+
* @param kernel the kernel to run
666+
* @param ...args the inputs to the kernel
667+
*/
668+
template <uint32_t blockThreads, typename KernelT, typename... ArgsT>
669+
void run_kernel(const std::vector<Op> op, KernelT kernel, ArgsT... args)
670+
{
671+
run_kernel<blockThreads>(
672+
kernel,
673+
op,
674+
false,
675+
false,
676+
false,
677+
[](uint32_t v, uint32_t e, uint32_t f) { return 0; },
678+
NULL,
679+
args...);
680+
}
681+
682+
/**
683+
* @brief run a kernel that will require a query operation
684+
* @tparam ...ArgsT infered
685+
* @tparam blockThreads the block size
686+
* @param op list of query operations used inside the kernel
687+
* @param kernel the kernel to run
688+
* @param oriented are the query operation required to be oriented
689+
* @param with_vertex_valence if vertex valence is requested to be
690+
* pre-computed and stored in shared memory
691+
* @param is_concurrent in case of multiple queries (i.e., op.size() > 1),
692+
* this parameter indicates if queries needs to be access at the same time
693+
* @param user_shmem a (lambda) function that takes the number of vertices,
694+
* edges, and faces and returns additional user-desired shared memory in
695+
* bytes. In case no extra shared memory needed, it can be
696+
* [](uint32_t v, uint32_t e, uint32_t f) { return 0; }
697+
* @param ...args the inputs to the kernel
698+
*/
699+
template <uint32_t blockThreads, typename KernelT, typename... ArgsT>
700+
void run_kernel(
701+
const KernelT kernel,
702+
const std::vector<Op> op,
703+
const bool oriented,
704+
const bool with_vertex_valence,
705+
const bool is_concurrent,
706+
std::function<size_t(uint32_t, uint32_t, uint32_t)> user_shmem,
707+
cudaStream_t stream,
708+
ArgsT... args) const
709+
{
710+
LaunchBox<blockThreads> lb;
711+
712+
prepare_launch_box(op,
713+
lb,
714+
kernel,
715+
oriented,
716+
with_vertex_valence,
717+
is_concurrent,
718+
user_shmem);
719+
720+
kernel<<<lb.blocks, lb.num_threads, lb.smem_bytes_dyn, stream>>>(
721+
get_context(), args...);
722+
}
723+
724+
642725
/**
643726
* @brief populate the launch_box with grid size and dynamic shared memory
644727
* needed for kernel launch

0 commit comments

Comments
 (0)