You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardexpand all lines: csrc/kernels.hip
+11-11
Original file line number
Diff line number
Diff line change
@@ -2853,6 +2853,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T
2853
2853
#define DENORM 1.0f/127.0f
2854
2854
#define MAX_SPARSE_COUNT 32
2855
2855
#define SMEM_SIZE 8*256
2856
+
#define WARP_SIZE warpSize
2856
2857
template <typename T, int SPMM_ITEMS, int BITS>
2857
2858
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
2858
2859
{
@@ -2873,9 +2874,9 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
const int warp_offset = (warp_id*WARP_SIZE)*SPMM_ITEMS;
2879
2880
const int num_items = BITS == 8 ? 8 : 8;
2880
2881
int idx_col_B = warp_offset;
2881
2882
int local_idx_col_B_offset = 0;
@@ -2895,7 +2896,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
2895
2896
}
2896
2897
2897
2898
// each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192
2898
-
// we expect each warp to be SPMM_ITEMS*32 apart
2899
+
// we expect each warp to be SPMM_ITEMS*WARP_SIZE apart
2899
2900
// we have a total of 128 bytes for the bank with a bank size of 4 bytes
2900
2901
// added 3 bytes = 6 values between warps should reduce bank conflicts
2901
2902
__shared__ half smem_dequant_stats[SMEM_SIZE];
@@ -3543,7 +3544,6 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
3543
3544
#endif
3544
3545
}
3545
3546
3546
-
#define warp_size __AMDGCN_WAVEFRONT_SIZE
3547
3547
// No of 4bit values processed by each thread
3548
3548
#define num_values_4bit 32
3549
3549
template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
@@ -3553,12 +3553,12 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
3553
3553
// load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
3554
3554
// 4 warps -> 4 loads per iter
3555
3555
// 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
0 commit comments