Commit 62fa8a8

Adjust kQuantizeBlockwise to work with WARP size 64
1 parent 153a23d commit 62fa8a8

2 files changed: +23 -20 lines


csrc/kernels.cu

+22 -15
@@ -740,21 +740,28 @@ template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
 //__launch_bounds__(TH, 4)
 __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
 {
+#ifdef BNB_USE_HIP
+  const int CUB_NUM_PER_TH=(BLOCK_SIZE/NUM_PER_TH % __AMDGCN_WAVEFRONT_SIZE == 0) ? NUM_PER_TH : NUM_PER_TH/2;
+#else
+  const int CUB_NUM_PER_TH=NUM_PER_TH;
+#endif
+  const int DATA_NUM_PER_TH=(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH;
+
   const int n_full = gridDim.x * BLOCK_SIZE;
   int valid_items = 0;
   const int base_idx = (blockIdx.x * BLOCK_SIZE);

-  T vals[NUM_PER_TH];
-  float rand_vals[NUM_PER_TH];
-  unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH];
+  T vals[CUB_NUM_PER_TH];
+  float rand_vals[CUB_NUM_PER_TH];
+  unsigned char qvals[DATA_NUM_PER_TH];
   //float local_abs_max = -FLT_MAX;
   float local_abs_max = 0.0f;
   int local_rand_idx = 0;

-  typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
-  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
-  typedef cub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
-  typedef cub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
+  typedef cub::BlockLoad<T, BLOCK_SIZE/CUB_NUM_PER_TH, CUB_NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
+  typedef cub::BlockStore<unsigned char, BLOCK_SIZE/CUB_NUM_PER_TH, DATA_NUM_PER_TH, cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
+  typedef cub::BlockReduce<float, BLOCK_SIZE/CUB_NUM_PER_TH> BlockReduce;
+  typedef cub::BlockLoad<float, BLOCK_SIZE/CUB_NUM_PER_TH, CUB_NUM_PER_TH, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;

   __shared__ typename LoadT::TempStorage loadt;
   __shared__ typename LoadFloat::TempStorage loadf;
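Note on the hunk above: CUB_NUM_PER_TH only deviates from NUM_PER_TH when the thread count implied by the template parameters (BLOCK_SIZE/NUM_PER_TH) is not a multiple of the wavefront size, which the WARP_TRANSPOSE variants of cub::BlockLoad/BlockStore expect it to be. A minimal host-side sketch (not part of the commit, with hypothetical BLOCK_SIZE/NUM_PER_TH values) of how that ternary resolves:

```cpp
// Sketch only: same expression as the kernel, evaluated for example parameters.
#include <cstdio>

constexpr int cub_num_per_th(int block_size, int num_per_th, int wavefront) {
    // Keep NUM_PER_TH when the implied thread count is already a multiple of
    // the wavefront, otherwise halve it so the CUB template parameters see
    // twice as many threads per block.
    return (block_size / num_per_th % wavefront == 0) ? num_per_th : num_per_th / 2;
}

int main() {
    // 2048 elements, 4 per thread -> 512 threads: a multiple of both 32 and 64.
    printf("%d %d\n", cub_num_per_th(2048, 4, 32), cub_num_per_th(2048, 4, 64));
    // 128 elements, 4 per thread -> 32 threads: fine for warp 32, but not a
    // multiple of wavefront 64, so items per thread drops to 2.
    printf("%d %d\n", cub_num_per_th(128, 4, 32), cub_num_per_th(128, 4, 64));
    return 0;
}
```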
@@ -779,8 +786,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
   // 2. broadcast local max
   // 3. normalize inputs and quantize

-  #pragma unroll NUM_PER_TH
-  for(int j = 0; j < NUM_PER_TH; j++)
+  #pragma unroll CUB_NUM_PER_TH
+  for(int j = 0; j < CUB_NUM_PER_TH; j++)
     local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));

   local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);
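For context, the loop in this hunk follows the usual blockwise-quantization pattern: each thread folds its loaded items into a local absolute maximum, and cub::BlockReduce combines those into one absmax per block. A self-contained CUDA sketch of that pattern (hypothetical kernel, fixed 128 threads and one element per thread, not the bitsandbytes code):

```cuda
// Standalone illustration of per-thread max + cub::BlockReduce. block_absmax,
// the 128-thread block size, and the toy data are assumptions for this sketch.
#include <cub/cub.cuh>
#include <cstdio>

__global__ void block_absmax(const float* in, float* out) {
    typedef cub::BlockReduce<float, 128> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp;

    // Each thread contributes |x| of its single element; kQuantizeBlockwise
    // folds several items per thread before reducing.
    float local = fabsf(in[blockIdx.x * 128 + threadIdx.x]);
    float block_max = BlockReduce(temp).Reduce(local, cub::Max());

    if (threadIdx.x == 0)
        out[blockIdx.x] = block_max;   // one absmax per quantization block
}

int main() {
    const int N = 256;                 // two blocks of 128 elements
    float h_in[N], h_out[2], *d_in, *d_out;
    for (int i = 0; i < N; i++) h_in[i] = (float)(i % 7) - 3.0f;

    cudaMalloc(&d_in, N * sizeof(float));
    cudaMalloc(&d_out, 2 * sizeof(float));
    cudaMemcpy(d_in, h_in, N * sizeof(float), cudaMemcpyHostToDevice);

    block_absmax<<<2, 128>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, 2 * sizeof(float), cudaMemcpyDeviceToHost);

    printf("absmax per block: %.1f %.1f\n", h_out[0], h_out[1]);
    cudaFree(d_in); cudaFree(d_out);
    return 0;
}
```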
@@ -809,8 +816,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
   switch(DATA_TYPE)
   {
     case General8bit:
-      #pragma unroll NUM_PER_TH
-      for(int j = 0; j < NUM_PER_TH; j++)
+      #pragma unroll CUB_NUM_PER_TH
+      for(int j = 0; j < CUB_NUM_PER_TH; j++)
       {
         if(!STOCHASTIC)
           qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max);
@@ -819,17 +826,17 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
       }
       break;
     case FP4:
-      #pragma unroll NUM_PER_TH
-      for(int j = 0; j < NUM_PER_TH/2; j++)
+      #pragma unroll CUB_NUM_PER_TH
+      for(int j = 0; j < DATA_NUM_PER_TH; j++)
       {
         packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
         packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
         qvals[j] = packed_4bit;
       }
       break;
     case NF4:
-      #pragma unroll NUM_PER_TH
-      for(int j = 0; j < NUM_PER_TH/2; j++)
+      #pragma unroll CUB_NUM_PER_TH
+      for(int j = 0; j < DATA_NUM_PER_TH; j++)
       {
         packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
         packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
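In the FP4/NF4 branches two quantized values share one output byte (high and low nibble), which is why these loops now run DATA_NUM_PER_TH = NUM_PER_TH/2 times while still indexing vals[2*j] and vals[2*j+1]. A small host-only sketch of that packing, with quantize4bit standing in for dQuantizeFP4/dQuantizeNF4:

```cpp
// Host-only illustration of packing two 4-bit codes per byte. quantize4bit is
// a made-up placeholder, not the FP4/NF4 quantizer from kernels.cu.
#include <cstdint>
#include <cstdio>

static uint8_t quantize4bit(float x) {
    // Toy 4-bit code: scale to [0, 15] and clamp. For illustration only.
    int q = (int)(x * 15.0f + 0.5f);
    return (uint8_t)(q < 0 ? 0 : (q > 15 ? 15 : q));
}

int main() {
    const float vals[4] = {0.10f, 0.95f, 0.50f, 0.30f};  // already normalized by absmax
    uint8_t qvals[2];                                     // 4 inputs -> 2 packed bytes

    for (int j = 0; j < 2; j++) {                         // DATA_NUM_PER_TH iterations
        uint8_t packed_4bit = 0;
        packed_4bit |= quantize4bit(vals[2 * j]) << 4;    // even element -> high nibble
        packed_4bit |= quantize4bit(vals[2 * j + 1]);     // odd element  -> low nibble
        qvals[j] = packed_4bit;
    }

    printf("0x%02X 0x%02X\n", qvals[0], qvals[1]);        // 0x2E 0x85 for these inputs
    return 0;
}
```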

csrc/ops.cuh

+1 -5
@@ -14,10 +14,6 @@


 #ifdef BNB_USE_HIP
-// check rocminfo | grep "Wavefront Size". Should be supported on all new GPU's
-// dirty hack to force wavefront_size 32 so this compiles
-// RDNA 2 defaults to 64 which conflicts with kQuantizeBlockwise
-#define __AMDGCN_WAVEFRONT_SIZE 32

 #include <hip/hip_runtime_api.h>
 #include <hip/hip_fp16.h>
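With kQuantizeBlockwise now adapting to __AMDGCN_WAVEFRONT_SIZE directly, the forced wave32 override and its comments can be dropped. A minimal host-side check (a sketch, assuming a HIP toolchain) of the wavefront size the runtime reports, along the lines of the rocminfo hint in the removed comment:

```cpp
// Prints the wavefront (warp) size the HIP runtime reports for device 0.
// This mirrors `rocminfo | grep "Wavefront Size"`; it does not change the
// compile-time __AMDGCN_WAVEFRONT_SIZE value the kernel is built with.
#include <hip/hip_runtime.h>
#include <cstdio>

int main() {
    hipDeviceProp_t prop;
    if (hipGetDeviceProperties(&prop, 0) != hipSuccess) {
        fprintf(stderr, "failed to query device 0\n");
        return 1;
    }
    // 64 on many AMD GPUs (e.g. CDNA), which is the case this commit accommodates.
    printf("wavefront size: %d\n", prop.warpSize);
    return 0;
}
```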
@@ -58,7 +54,7 @@
 #define cublasLtHandle_t hipblasLtHandle_t
 #define cublasLtCreate hipblasLtCreate
 #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
-#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT //TODO: HIP didn't have the right one, might cause issues
+#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT

 #else
 #include <cuda_runtime_api.h>
