Skip to content

Commit c57013a

Browse files
committed
Attempted CI bugfix.
1 parent 25abd18 commit c57013a

1 file changed

Lines changed: 9 additions & 5 deletions

File tree

openequivariance/openequivariance/extension/group_mm.hpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,16 @@
2929
};
3030
#endif
3131

32-
inline std::unique_ptr<BlasHandle> g_blas_handle = std::make_unique<BlasHandle>();
32+
inline BlasHandle& get_blas_handle() {
33+
static BlasHandle handle;
34+
return handle;
35+
}
3336

3437
template<typename T>
3538
void group_gemm_blas(void* A_raw, void* B_raw, void* C_raw,
3639
int64_t* ragged_counts, int num_W, int batch_size, int m, int k, int ragged_inner) {
3740

41+
auto& blas = get_blas_handle();
3842
T alpha = 1.0, beta = 0.0;
3943
T* A_base = reinterpret_cast<T*>(A_raw);
4044
T* B_base = reinterpret_cast<T*>(B_raw);
@@ -83,7 +87,7 @@ void group_gemm_blas(void* A_raw, void* B_raw, void* C_raw,
8387
#ifdef CUDA_BACKEND
8488
cublasStatus_t stat;
8589
if (std::is_same<T, float>::value) {
86-
stat = cublasSgemmStridedBatched(g_blas_handle->handle,
90+
stat = cublasSgemmStridedBatched(blas.handle,
8791
transa, transb, M, N, K,
8892
reinterpret_cast<float*>(&alpha),
8993
reinterpret_cast<float*>(A), lda, strideA,
@@ -92,7 +96,7 @@ void group_gemm_blas(void* A_raw, void* B_raw, void* C_raw,
9296
reinterpret_cast<float*>(C), ldc, strideC,
9397
batch_size);
9498
} else if (std::is_same<T, double>::value) {
95-
stat = cublasDgemmStridedBatched(g_blas_handle->handle,
99+
stat = cublasDgemmStridedBatched(blas.handle,
96100
transa, transb, M, N, K,
97101
reinterpret_cast<double*>(&alpha),
98102
reinterpret_cast<double*>(A), lda, strideA,
@@ -108,7 +112,7 @@ void group_gemm_blas(void* A_raw, void* B_raw, void* C_raw,
108112
#elif defined(HIP_BACKEND)
109113
rocblas_status stat;
110114
if (std::is_same<T, float>::value) {
111-
stat = rocblas_sgemm_strided_batched(g_blas_handle->handle,
115+
stat = rocblas_sgemm_strided_batched(blas.handle,
112116
transa, transb, M, N, K,
113117
reinterpret_cast<float*>(&alpha),
114118
reinterpret_cast<float*>(A), lda, strideA,
@@ -117,7 +121,7 @@ void group_gemm_blas(void* A_raw, void* B_raw, void* C_raw,
117121
reinterpret_cast<float*>(C), ldc, strideC,
118122
batch_size);
119123
} else if (std::is_same<T, double>::value) {
120-
stat = rocblas_dgemm_strided_batched(g_blas_handle->handle,
124+
stat = rocblas_dgemm_strided_batched(blas.handle,
121125
transa, transb, M, N, K,
122126
reinterpret_cast<double*>(&alpha),
123127
reinterpret_cast<double*>(A), lda, strideA,

0 commit comments

Comments
 (0)