2929 };
3030#endif
3131
32- inline std::unique_ptr<BlasHandle> g_blas_handle = std::make_unique<BlasHandle>();
32+ inline BlasHandle& get_blas_handle () {
33+ static BlasHandle handle;
34+ return handle;
35+ }
3336
3437template <typename T>
3538void group_gemm_blas (void * A_raw, void * B_raw, void * C_raw,
3639 int64_t * ragged_counts, int num_W, int batch_size, int m, int k, int ragged_inner) {
3740
41+ auto & blas = get_blas_handle ();
3842 T alpha = 1.0 , beta = 0.0 ;
3943 T* A_base = reinterpret_cast <T*>(A_raw);
4044 T* B_base = reinterpret_cast <T*>(B_raw);
@@ -83,7 +87,7 @@ void group_gemm_blas(void* A_raw, void* B_raw, void* C_raw,
8387#ifdef CUDA_BACKEND
8488 cublasStatus_t stat;
8589 if (std::is_same<T, float >::value) {
86- stat = cublasSgemmStridedBatched (g_blas_handle-> handle ,
90+ stat = cublasSgemmStridedBatched (blas. handle ,
8791 transa, transb, M, N, K,
8892 reinterpret_cast <float *>(&alpha),
8993 reinterpret_cast <float *>(A), lda, strideA,
@@ -92,7 +96,7 @@ void group_gemm_blas(void* A_raw, void* B_raw, void* C_raw,
9296 reinterpret_cast <float *>(C), ldc, strideC,
9397 batch_size);
9498 } else if (std::is_same<T, double >::value) {
95- stat = cublasDgemmStridedBatched (g_blas_handle-> handle ,
99+ stat = cublasDgemmStridedBatched (blas. handle ,
96100 transa, transb, M, N, K,
97101 reinterpret_cast <double *>(&alpha),
98102 reinterpret_cast <double *>(A), lda, strideA,
@@ -108,7 +112,7 @@ void group_gemm_blas(void* A_raw, void* B_raw, void* C_raw,
108112#elif defined(HIP_BACKEND)
109113 rocblas_status stat;
110114 if (std::is_same<T, float >::value) {
111- stat = rocblas_sgemm_strided_batched (g_blas_handle-> handle ,
115+ stat = rocblas_sgemm_strided_batched (blas. handle ,
112116 transa, transb, M, N, K,
113117 reinterpret_cast <float *>(&alpha),
114118 reinterpret_cast <float *>(A), lda, strideA,
@@ -117,7 +121,7 @@ void group_gemm_blas(void* A_raw, void* B_raw, void* C_raw,
117121 reinterpret_cast <float *>(C), ldc, strideC,
118122 batch_size);
119123 } else if (std::is_same<T, double >::value) {
120- stat = rocblas_dgemm_strided_batched (g_blas_handle-> handle ,
124+ stat = rocblas_dgemm_strided_batched (blas. handle ,
121125 transa, transb, M, N, K,
122126 reinterpret_cast <double *>(&alpha),
123127 reinterpret_cast <double *>(A), lda, strideA,
0 commit comments