@@ -740,21 +740,28 @@ template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TY
740
740
// __launch_bounds__(TH, 4)
741
741
__global__ void kQuantizeBlockwise (float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
742
742
{
743
+ #ifdef BNB_USE_HIP
744
+ const int CUB_NUM_PER_TH=(BLOCK_SIZE/NUM_PER_TH % __AMDGCN_WAVEFRONT_SIZE == 0 ) ? NUM_PER_TH : NUM_PER_TH/2 ;
745
+ #else
746
+ const int CUB_NUM_PER_TH=NUM_PER_TH;
747
+ #endif
748
+ const int DATA_NUM_PER_TH=(DATA_TYPE > 0 ) ? NUM_PER_TH/2 : NUM_PER_TH;
749
+
743
750
const int n_full = gridDim .x * BLOCK_SIZE;
744
751
int valid_items = 0 ;
745
752
const int base_idx = (blockIdx .x * BLOCK_SIZE);
746
753
747
- T vals[NUM_PER_TH ];
748
- float rand_vals[NUM_PER_TH ];
749
- unsigned char qvals[(DATA_TYPE > 0 ) ? NUM_PER_TH/ 2 : NUM_PER_TH ];
754
+ T vals[CUB_NUM_PER_TH ];
755
+ float rand_vals[CUB_NUM_PER_TH ];
756
+ unsigned char qvals[DATA_NUM_PER_TH ];
750
757
// float local_abs_max = -FLT_MAX;
751
758
float local_abs_max = 0 .0f ;
752
759
int local_rand_idx = 0 ;
753
760
754
- typedef cub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH , cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
755
- typedef cub::BlockStore<unsigned char , BLOCK_SIZE/NUM_PER_TH, (DATA_TYPE > 0 ) ? NUM_PER_TH/ 2 : NUM_PER_TH , cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
756
- typedef cub::BlockReduce<float , BLOCK_SIZE/NUM_PER_TH > BlockReduce;
757
- typedef cub::BlockLoad<float , BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH , cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
761
+ typedef cub::BlockLoad<T, BLOCK_SIZE/CUB_NUM_PER_TH, CUB_NUM_PER_TH , cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
762
+ typedef cub::BlockStore<unsigned char , BLOCK_SIZE/CUB_NUM_PER_TH, DATA_NUM_PER_TH , cub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
763
+ typedef cub::BlockReduce<float , BLOCK_SIZE/CUB_NUM_PER_TH > BlockReduce;
764
+ typedef cub::BlockLoad<float , BLOCK_SIZE/CUB_NUM_PER_TH, CUB_NUM_PER_TH , cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
758
765
759
766
__shared__ typename LoadT::TempStorage loadt;
760
767
__shared__ typename LoadFloat::TempStorage loadf;
@@ -779,8 +786,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
779
786
// 2. broadcast local max
780
787
// 3. normalize inputs and quantize
781
788
782
- #pragma unroll NUM_PER_TH
783
- for (int j = 0 ; j < NUM_PER_TH ; j++)
789
+ #pragma unroll CUB_NUM_PER_TH
790
+ for (int j = 0 ; j < CUB_NUM_PER_TH ; j++)
784
791
local_abs_max = fmaxf (local_abs_max, fabsf ((float )vals[j]));
785
792
786
793
local_abs_max = BlockReduce (reduce).Reduce (local_abs_max, cub::Max (), valid_items);
@@ -809,8 +816,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
809
816
switch (DATA_TYPE)
810
817
{
811
818
case General8bit:
812
- #pragma unroll NUM_PER_TH
813
- for (int j = 0 ; j < NUM_PER_TH ; j++)
819
+ #pragma unroll CUB_NUM_PER_TH
820
+ for (int j = 0 ; j < CUB_NUM_PER_TH ; j++)
814
821
{
815
822
if (!STOCHASTIC)
816
823
qvals[j] = dQuantize<0 >(smem_code, 0 .0f , ((float )vals[j])*local_abs_max);
@@ -819,17 +826,17 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
819
826
}
820
827
break ;
821
828
case FP4:
822
- #pragma unroll NUM_PER_TH
823
- for (int j = 0 ; j < NUM_PER_TH/ 2 ; j++)
829
+ #pragma unroll CUB_NUM_PER_TH
830
+ for (int j = 0 ; j < DATA_NUM_PER_TH ; j++)
824
831
{
825
832
packed_4bit |= dQuantizeFP4 (((float )vals[2 *j])*local_abs_max) << 4 ;
826
833
packed_4bit |= dQuantizeFP4 (((float )vals[2 *j+1 ])*local_abs_max);
827
834
qvals[j] = packed_4bit;
828
835
}
829
836
break ;
830
837
case NF4:
831
- #pragma unroll NUM_PER_TH
832
- for (int j = 0 ; j < NUM_PER_TH/ 2 ; j++)
838
+ #pragma unroll CUB_NUM_PER_TH
839
+ for (int j = 0 ; j < DATA_NUM_PER_TH ; j++)
833
840
{
834
841
packed_4bit |= dQuantizeNF4 (((float )vals[2 *j])*local_abs_max) << 4 ;
835
842
packed_4bit |= dQuantizeNF4 (((float )vals[2 *j+1 ])*local_abs_max);
0 commit comments