Skip to content

Commit e4fe8b5

Browse files
authored
Merge pull request #50 from ROCm/spmm_naive_warpsize_64
Update spmm naive kernel for warpsize 64
2 parents 4aad810 + 5da9d99 commit e4fe8b5

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

csrc/kernels.hip

+11-11
Original file line numberDiff line numberDiff line change
@@ -2853,6 +2853,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
28532853
#define DENORM 1.0f/127.0f
28542854
#define MAX_SPARSE_COUNT 32
28552855
#define SMEM_SIZE 8*256
2856+
#define WARP_SIZE warpSize
28562857
template <typename T, int SPMM_ITEMS, int BITS>
28572858
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
28582859
{
@@ -2873,9 +2874,9 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
28732874
const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1];
28742875
const int local_row_idx = rowidx[offset];
28752876

2876-
const int warp_id = threadIdx.x / 32;
2877-
const int warp_idx = threadIdx.x % 32;
2878-
const int warp_offset = (warp_id*32)*SPMM_ITEMS;
2877+
const int warp_id = threadIdx.x / WARP_SIZE;
2878+
const int warp_idx = threadIdx.x % WARP_SIZE;
2879+
const int warp_offset = (warp_id*WARP_SIZE)*SPMM_ITEMS;
28792880
const int num_items = BITS == 8 ? 8 : 8;
28802881
int idx_col_B = warp_offset;
28812882
int local_idx_col_B_offset = 0;
@@ -2895,7 +2896,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
28952896
}
28962897

28972898
// each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192
2898-
// we expect each warp to be SPMM_ITEMS*32 apart
2899+
// we expect each warp to be SPMM_ITEMS*WARP_SIZE apart
28992900
// we have a total of 128 bytes for the bank with a bank size of 4 bytes
29002901
// added 3 bytes = 6 values between warps should reduce bank conflicts
29012902
__shared__ half smem_dequant_stats[SMEM_SIZE];
@@ -3543,7 +3544,6 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
35433544
#endif
35443545
}
35453546

3546-
#define warp_size __AMDGCN_WAVEFRONT_SIZE
35473547
// No of 4bit values processed by each thread
35483548
#define num_values_4bit 32
35493549
template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
@@ -3553,12 +3553,12 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
35533553
// load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
35543554
// 4 warps -> 4 loads per iter
35553555
// 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
3556-
typedef hipcub::WarpReduce<float, warp_size> WarpReduce;
3557-
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warp_size];
3556+
typedef hipcub::WarpReduce<float, warpSize> WarpReduce;
3557+
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warpSize];
35583558

3559-
const int warp_idx = threadIdx.x / warp_size;
3560-
const int warp_lane = threadIdx.x % warp_size;
3561-
const int row_B = (THREADS/warp_size)*blockIdx.x + warp_idx;
3559+
const int warp_idx = threadIdx.x / warpSize;
3560+
const int warp_lane = threadIdx.x % warpSize;
3561+
const int row_B = (THREADS/warpSize)*blockIdx.x + warp_idx;
35623562
const int num_values_8bit = num_values_4bit/2;
35633563
float local_C = 0.0f;
35643564

@@ -3574,7 +3574,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
35743574

35753575
// A: [1, K]
35763576
// B: [M, K]
3577-
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warp_size*num_values_4bit)
3577+
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warpSize*num_values_4bit)
35783578
{
35793579
int inner_idx_halved = inner_idx/2;
35803580
int offset_B = ldb*row_B;

csrc/ops.hip

+3-3
Original file line numberDiff line numberDiff line change
@@ -904,9 +904,9 @@ template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int
904904
//warpsize - 32
905905
int num_blocks = (m+3)/4;
906906
//warpsize - 64
907-
#if __AMDGCN_WAVEFRONT_SIZE == 64
908-
num_blocks = (m+1)/2;
909-
#endif
907+
if (warpSize == 64) {
908+
num_blocks = (m+1)/2;
909+
}
910910

911911
hipLaunchKernelGGL(( kgemm_4bit_inference_naive<T, 128, BITS>), dim3(num_blocks), dim3(128), 0, 0 , m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
912912
CUDA_CHECK_RETURN(hipPeekAtLastError());

0 commit comments

Comments
 (0)