You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: csrc/kernels.hip
+11-11Lines changed: 11 additions & 11 deletions
Original file line number
Diff line number
Diff line change
@@ -2853,6 +2853,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
2853
2853
#define DENORM 1.0f/127.0f
2854
2854
#define MAX_SPARSE_COUNT 32
2855
2855
#define SMEM_SIZE 8*256
2856
+
#define WARP_SIZE warpSize
2856
2857
template <typename T, int SPMM_ITEMS, int BITS>
2857
2858
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
2858
2859
{
@@ -2873,9 +2874,9 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
@@ -2895,7 +2896,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
2895
2896
}
2896
2897
2897
2898
// each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=8192
2898
-
// we expect each warp to be SPMM_ITEMS*32 apart
2899
+
// we expect each warp to be SPMM_ITEMS*WARP_SIZE apart
2899
2900
// we have a total of 128 bytes for the bank with a bank size of 4 bytes
2900
2901
// added 3 bytes = 6 values between warps should reduce bank conflicts
2901
2902
__shared__ half smem_dequant_stats[SMEM_SIZE];
@@ -3543,7 +3544,6 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
3543
3544
#endif
3544
3545
}
3545
3546
3546
-
#define warp_size __AMDGCN_WAVEFRONT_SIZE
3547
3547
// No of 4bit values processed by each thread
3548
3548
#define num_values_4bit 32
3549
3549
template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
@@ -3553,12 +3553,12 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
3553
3553
// load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
3554
3554
// 4 warps -> 4 loads per iter
3555
3555
// 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
0 commit comments