Skip to content

Commit 6f7986d

Browse files
authored
Cleaner qmv/qvm (#1616)
1 parent 7cbb4ae commit 6f7986d

File tree

2 files changed

+8
-33
lines changed

2 files changed

+8
-33
lines changed

mlx/backend/metal/kernels/quantized.h

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -633,16 +633,12 @@ METAL_FUNC void qmv_fast_impl(
633633
constexpr int num_simdgroups = 2;
634634
constexpr int results_per_simdgroup = 4;
635635
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
636-
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
636+
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
637637
constexpr int values_per_thread = pack_factor * packs_per_thread;
638638
constexpr int block_size = values_per_thread * SIMD_SIZE;
639639
constexpr int scale_step_per_thread = group_size / values_per_thread;
640640

641-
// When bits is a power of two, read 1 uint32_t at a time
642-
// When bits is 3 or 6, read 3 uint8_ts at a time
643-
using W_T =
644-
typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
645-
const device W_T* ws = (const device W_T*)w;
641+
const device uint8_t* ws = (const device uint8_t*)w;
646642

647643
typedef float U;
648644

@@ -705,16 +701,12 @@ METAL_FUNC void qmv_impl(
705701
constexpr int results_per_simdgroup = 4;
706702
constexpr int packs_per_thread = 1;
707703
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
708-
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
704+
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
709705
constexpr int values_per_thread = pack_factor * packs_per_thread;
710706
constexpr int block_size = values_per_thread * SIMD_SIZE;
711707
constexpr int scale_step_per_thread = group_size / values_per_thread;
712708

713-
// When bits is a power of two, read 1 uint32_t at a time
714-
// When bits is 3 or 6, read 3 uint8_ts at a time
715-
using W_T =
716-
typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
717-
const device W_T* ws = (const device W_T*)w;
709+
const device uint8_t* ws = (const device uint8_t*)w;
718710

719711
typedef float U;
720712

@@ -862,19 +854,15 @@ METAL_FUNC void qvm_impl(
862854
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
863855
constexpr int num_simdgroups = 2;
864856
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
865-
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
857+
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
866858
constexpr int tn = 32 / pack_factor;
867859
constexpr int block_size = SIMD_SIZE;
868860

869-
// When bits is a power of two, read 1 uint32_t at a time
870-
// When bits is 3 or 6, read 3 uint8_ts at a time
871-
using W_T =
872-
typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
873-
const device W_T* ws = (const device W_T*)w;
861+
const device uint8_t* ws = (const device uint8_t*)w;
874862

875863
typedef float U;
876864
typedef struct {
877-
W_T wi[tn * bytes_per_pack];
865+
uint8_t wi[tn * bytes_per_pack];
878866
} vec_w;
879867

880868
thread vec_w w_local;
@@ -2070,9 +2058,7 @@ template <typename T, const int group_size, const int bits>
20702058
}
20712059

20722060
// We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t
2073-
using OutT =
2074-
typename ConditionalType<power_of_2_bits, uint8_t, uint32_t>::type;
2075-
OutT output = 0;
2061+
uint32_t output = 0;
20762062

20772063
#pragma clang loop unroll(full)
20782064
for (int i = 0; i < values_per_reduce; i++) {

mlx/backend/metal/kernels/utils.h

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -421,14 +421,3 @@ inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) {
421421
return complex64_t(
422422
simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane));
423423
}
424-
425-
// std::conditional is not included with Metal
426-
template <bool condition, typename T, typename U>
427-
struct ConditionalType {
428-
using type = U;
429-
};
430-
431-
template <typename T, typename U>
432-
struct ConditionalType<true, T, U> {
433-
using type = T;
434-
};

0 commit comments

Comments
 (0)