-      // we got 32x32 tiles and we use the magic equation from the cublasLt doc to get the element
-      // within each tile.
-      int offset_per_col_tile = ((rowsA+31)/32)*32*32;
-      int tile_offset_rows = (row/32)*32*32;
-      int tile_offset_cols = (local_colidx/32)*offset_per_col_tile;
-      int subtile_col_idx = local_colidx%32;
-      int subtile_row_idx = row % 32;
-      // this magic is taken from the cublasLt doc (search for COL32)
-      int offset = (((subtile_row_idx%8)/2*4+subtile_row_idx/8)*2+subtile_row_idx%2)*32+subtile_col_idx;
-      offset += tile_offset_cols + tile_offset_rows;
-
-      char val = A[offset];
-      int out_idx = (row*idx_size) + blockIdx.x;
-      out[out_idx] = val;
-    }
-  }
-}
-
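For readers of the removed branch above: A is stored as columns of 32x32 tiles (the COL32 order from the cublasLt docs), with rows inside each tile permuted by the "magic" expression. Below is a minimal host-side sketch of that index arithmetic; the name `col32_offset` and the `main` probe are illustrative, not part of the codebase:

```cpp
#include <cstdio>

// Sketch of the removed indexing: map element (row, col) of a rowsA-row
// matrix to its offset in the 32x32-tiled buffer, mirroring the kernel's
// arithmetic line for line.
int col32_offset(int row, int col, int rowsA)
{
    int offset_per_col_tile = ((rowsA + 31) / 32) * 32 * 32; // one 32-wide column of tiles
    int tile_offset_rows    = (row / 32) * 32 * 32;          // which tile down that column
    int tile_offset_cols    = (col / 32) * offset_per_col_tile;
    int subtile_col_idx     = col % 32;
    int subtile_row_idx     = row % 32;
    // the row permutation inside a tile ("search for COL32" in the cublasLt doc)
    int offset = (((subtile_row_idx % 8) / 2 * 4 + subtile_row_idx / 8) * 2
                  + subtile_row_idx % 2) * 32 + subtile_col_idx;
    return offset + tile_offset_cols + tile_offset_rows;
}

int main()
{
    // element (5, 40) of a 64-row matrix: second tile column, permuted row 17
    printf("%d\n", col32_offset(5, 40, 64)); // prints 2600
    return 0;
}
```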
#define WARPS 3
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
{
@@ -3049,9 +2986,6 @@ template __global__ void kgemm_4bit_inference_naive<half, 128, 16>(int M, int N,
template __global__ void kgemm_4bit_inference_naive<__nv_bfloat16, 128, 16>(int M, int N, int K, __nv_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);

- template __global__ void kExtractOutliers<COL_TURING>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
- template __global__ void kExtractOutliers<COL_AMPERE>(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
-
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
csrc/kernels.cuh (-2 lines)
@@ -121,8 +121,6 @@ template<typename T, int THREADS, int SPARSE_DECOMP> __global__ void kInt8Vector
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);

- template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA);
-
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc);
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize);
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
{
@@ -636,8 +610,6 @@ template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float *
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
- template void extractOutliers<COL_TURING>(char * A, int *idx, char *out, int idx_size, int rows, int cols);
- template void extractOutliers<COL_AMPERE>(char * A, int *idx, char *out, int idx_size, int rows, int cols);

template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
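The two deleted `template void extractOutliers<...>` lines were the explicit instantiations that gave other translation units something to link against, which is why this commit has to touch the declaration, the instantiations, and the C wrappers together. A self-contained illustration of that pattern (names invented for the example, not from the repo):

```cpp
// lib.cu -- template defined in a .cu file, so callers that only include the
// header rely on an explicit instantiation to get a linkable symbol.
template <int FORMAT>
void extract(char *A, int n)
{
    for (int i = 0; i < n; i++)
        A[i] += FORMAT; // stand-in body
}

// Without these lines, any other translation unit calling extract<0> or
// extract<1> fails at link time; with all callers gone, keeping them would
// only preserve dead symbols. Hence the paired removals in this commit.
template void extract<0>(char *A, int n);
template void extract<1>(char *A, int n);
```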
csrc/ops.cuh (-2 lines)
@@ -182,8 +182,6 @@ void spmm_coo(cusparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_val
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);

- template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols);
-
void matmul4bite(half *A, unsigned char *B, half *out, int lda, int ldb, int rowsA, int colsA, int colsB);

template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits);
void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n, stream); }
void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n, stream); }

- void extractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_TURING>(A, idx, out, idx_size, rows, cols); }
- void extractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers<COL_AMPERE>(A, idx, out, idx_size, rows, cols); }
-
int igemmlt_32(cublasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, cudaStream_t stream) {
    return igemmlt<32, 0>(ltHandle, m, n, k, A, B, C, row_scale, lda, ldb, ldc, stream);
}
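The removed shims were thin: an extern "C" entry point (resolvable from Python via ctypes) forwarding to the templated host wrapper, which launched kExtractOutliers with one block per outlier column. A hedged reconstruction of that host wrapper follows, assuming a 256-thread block and the tile heights implied by the kernel (8-row tiles for COL_TURING, 32-row tiles for COL_AMPERE):

```cpp
template <int FORMAT>
void extractOutliers(char *A, int *idx, char *out, int idx_size, int rows, int cols)
{
    // round rows/cols up to the tile granularity of the chosen format
    int tiledRows = (FORMAT == COL_TURING) ? ((rows + 7) / 8) * 8
                                           : ((rows + 31) / 32) * 32;
    int tiledCols = ((cols + 31) / 32) * 32;

    // one block per outlier column; each thread strides over the rows
    kExtractOutliers<FORMAT><<<idx_size, 256>>>(A, idx, out, idx_size,
                                                rows, cols, tiledRows, tiledCols);
}
```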
@@ -312,9 +309,6 @@ extern "C"
void cspmm_coo_very_sparse_naive_int8(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)

- void cextractOutliers_turing(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_turing(A, idx, out, idx_size, rows, cols); }
- void cextractOutliers_ampere(char * A, int *idx, char *out, int idx_size, int rows, int cols){ extractOutliers_ampere(A, idx, out, idx_size, rows, cols); }
-
//void cgemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host_fp32(M, N, K, A, B, out, lda, ldb, ldc); }