@@ -633,16 +633,12 @@ METAL_FUNC void qmv_fast_impl(
633633 constexpr int num_simdgroups = 2 ;
634634 constexpr int results_per_simdgroup = 4 ;
635635 constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
636- constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3 ;
636+ constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3 ;
637637 constexpr int values_per_thread = pack_factor * packs_per_thread;
638638 constexpr int block_size = values_per_thread * SIMD_SIZE;
639639 constexpr int scale_step_per_thread = group_size / values_per_thread;
640640
641- // When bits is a power of two, read 1 uint32_t at a time
642- // When bits is 3 or 6, read 3 uint8_ts at a time
643- using W_T =
644- typename ConditionalType<power_of_2_bits, uint32_t , uint8_t >::type;
645- const device W_T* ws = (const device W_T*)w;
641+ const device uint8_t * ws = (const device uint8_t *)w;
646642
647643 typedef float U;
648644
@@ -705,16 +701,12 @@ METAL_FUNC void qmv_impl(
705701 constexpr int results_per_simdgroup = 4 ;
706702 constexpr int packs_per_thread = 1 ;
707703 constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
708- constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3 ;
704+ constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3 ;
709705 constexpr int values_per_thread = pack_factor * packs_per_thread;
710706 constexpr int block_size = values_per_thread * SIMD_SIZE;
711707 constexpr int scale_step_per_thread = group_size / values_per_thread;
712708
713- // When bits is a power of two, read 1 uint32_t at a time
714- // When bits is 3 or 6, read 3 uint8_ts at a time
715- using W_T =
716- typename ConditionalType<power_of_2_bits, uint32_t , uint8_t >::type;
717- const device W_T* ws = (const device W_T*)w;
709+ const device uint8_t * ws = (const device uint8_t *)w;
718710
719711 typedef float U;
720712
@@ -862,19 +854,15 @@ METAL_FUNC void qvm_impl(
862854 constexpr int power_of_2_bits = (bits & (bits - 1 )) == 0 ;
863855 constexpr int num_simdgroups = 2 ;
864856 constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
865- constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3 ;
857+ constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3 ;
866858 constexpr int tn = 32 / pack_factor;
867859 constexpr int block_size = SIMD_SIZE;
868860
869- // When bits is a power of two, read 1 uint32_t at a time
870- // When bits is 3 or 6, read 3 uint8_ts at a time
871- using W_T =
872- typename ConditionalType<power_of_2_bits, uint32_t , uint8_t >::type;
873- const device W_T* ws = (const device W_T*)w;
861+ const device uint8_t * ws = (const device uint8_t *)w;
874862
875863 typedef float U;
876864 typedef struct {
877- W_T wi[tn * bytes_per_pack];
865+ uint8_t wi[tn * bytes_per_pack];
878866 } vec_w;
879867
880868 thread vec_w w_local;
@@ -2070,9 +2058,7 @@ template <typename T, const int group_size, const int bits>
20702058 }
20712059
20722060 // We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t
2073- using OutT =
2074- typename ConditionalType<power_of_2_bits, uint8_t , uint32_t >::type;
2075- OutT output = 0 ;
2061+ uint32_t output = 0 ;
20762062
20772063#pragma clang loop unroll(full)
20782064 for (int i = 0 ; i < values_per_reduce; i++) {
0 commit comments