
Trellis quantization #113


Draft: wants to merge 55 commits into main from ik/try_trellis

Changes from all commits (55 commits):
798f93c
WIP
Kawrakow Nov 5, 2024
c578478
WIP
Kawrakow Nov 5, 2024
afe9db7
WIP
Kawrakow Nov 5, 2024
f21dd3f
Testing Trellis quantization
Kawrakow Nov 5, 2024
9ec1455
Testing Trellis quantization: 4-bit quantized block scales
Kawrakow Nov 5, 2024
f1df1b7
Testing Trellis quantization: playing with scales and generators
Kawrakow Nov 5, 2024
a4f1ac8
iq2_kt: quantize / dequantize
Kawrakow Nov 5, 2024
426a6e6
iq2_kt: CUDA dequantize
Kawrakow Nov 6, 2024
a961a48
WIP
Kawrakow Nov 6, 2024
86948f9
WIP
Kawrakow Nov 6, 2024
766fa60
WIP - try larger blocks
Kawrakow Nov 6, 2024
36e9c92
iq2_kt - this is better
Kawrakow Nov 6, 2024
b3dfe99
iq2_kt - even better
Kawrakow Nov 7, 2024
d2331b9
iq2_kt: CUDA dot product
Kawrakow Nov 7, 2024
aed3910
iq2_kt: very slightly faster CUDA dot product
Kawrakow Nov 7, 2024
b354392
iq2_kt: f16 CUDA dot product
Kawrakow Nov 7, 2024
7cafafc
iq2_kt: faster f16 CUDA dot product
Kawrakow Nov 7, 2024
7bf6e15
iq2_kt: faster f16 CUDA dot product
Kawrakow Nov 7, 2024
590f472
Minor
Kawrakow Nov 7, 2024
4774788
Adding iq3_kt
Kawrakow Nov 7, 2024
977f94b
Forgotten change
Kawrakow Nov 7, 2024
08503ce
WIP
Kawrakow Nov 8, 2024
435eb9b
WIP
Kawrakow Nov 8, 2024
f1fb59b
iq3_kt WIP: slowly improving
Kawrakow Nov 8, 2024
386d139
WIP
Kawrakow Nov 9, 2024
dfcc8a9
iq3_kt WIP: slowly improving
Kawrakow Nov 9, 2024
8f0d075
iq3_kt WIP: slowly improving
Kawrakow Nov 9, 2024
c59830d
iq3_kt WIP: speed up quantization
Kawrakow Nov 10, 2024
e9e5879
iq3_kt speed up quantization
Kawrakow Nov 10, 2024
0ffc9b4
iq3_kt: CUDA dot product
Kawrakow Nov 10, 2024
4608f0c
iq2_kt: SOTA
Kawrakow Nov 10, 2024
47b28c1
iq2_kt: SOTA
Kawrakow Nov 11, 2024
00b4bff
Adding iq4_kt - not competitive at this point
Kawrakow Nov 11, 2024
1d6ca83
WIP
Kawrakow Nov 11, 2024
21ee589
WIP
Kawrakow Nov 11, 2024
e9ced1b
iq4_kt: CUDA dot product
Kawrakow Nov 11, 2024
de7fe92
iq4_kt: minor tweaks
Kawrakow Nov 11, 2024
200a19f
iq2_kt: SOTA
Kawrakow Nov 13, 2024
dbe0854
iq2_kt: SOTA
Kawrakow Nov 13, 2024
215bea5
iq3_kt: small improvements and faster quantization
Kawrakow Nov 13, 2024
4213ab1
iq2_kt: SOTA
Kawrakow Nov 14, 2024
c20b22b
iq3_kt: small progress
Kawrakow Nov 14, 2024
21903f1
WIP
Kawrakow Nov 14, 2024
1be0a9e
iq4_kt: go to 4.0 bpw
Kawrakow Nov 15, 2024
ab1cef3
iq4_kt: very slightly better
Kawrakow Nov 15, 2024
4cf82e7
iq4_kt: failed attempt to adjust CUDA dot product
Kawrakow Nov 15, 2024
e338e0a
DRY
Kawrakow Nov 15, 2024
79565c9
DRY
Kawrakow Nov 15, 2024
81cd220
iq4_kt: CUDA dot product works
Kawrakow Nov 15, 2024
3ee5434
DRY
Kawrakow Nov 15, 2024
5705dc7
Report actual bpw
Kawrakow Nov 15, 2024
2be4cff
Minor tweaks
Kawrakow Nov 18, 2024
3a9926b
Checkpoint
Kawrakow Nov 19, 2024
385a4f5
Merge remote-tracking branch 'origin/main' into ik/try_trellis
Kawrakow Jan 24, 2025
c13027b
Merge remote-tracking branch 'origin/main' into ik/try_trellis
Kawrakow Feb 9, 2025
17 changes: 16 additions & 1 deletion examples/quantize-stats/CMakeLists.txt
@@ -1,6 +1,21 @@
set(ARCH_FLAGS "")
if (NOT MSVC)
list(APPEND ARCH_FLAGS -march=native)
endif()
message(STATUS "ARCH_FLAGS = ${ARCH_FLAGS}")
#if (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
# (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
# CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
# message(STATUS "x86 detected")
# if (NOT MSVC)
# list(APPEND ARCH_FLAGS -march=native)
# endif()
#endif()

add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
set(TARGET llama-quantize-stats)
add_executable(${TARGET} quantize-stats.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama build_info ${CMAKE_THREAD_LIBS_INIT})
target_include_directories(${TARGET} PRIVATE ../../common)
target_compile_features(${TARGET} PRIVATE cxx_std_11)
target_compile_features(${TARGET} PRIVATE cxx_std_17)
709 changes: 704 additions & 5 deletions examples/quantize-stats/quantize-stats.cpp

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions examples/quantize/quantize.cpp
@@ -39,6 +39,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q2_K_R4", LLAMA_FTYPE_MOSTLY_Q2_K_R4, "Q2_K_S repacked", },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
{ "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", },
{ "IQ3_KT", LLAMA_FTYPE_MOSTLY_IQ3_KT, " 3.125 bpw trellis quantization", },
{ "IQ4_KT", LLAMA_FTYPE_MOSTLY_IQ4_KT, " 4.0 bpw trellis quantization", },
{ "IQ3_XXS_R4",LLAMA_FTYPE_MOSTLY_IQ3_XXS_R4,"IQ3_XXS repacked", },
{ "IQ3_S", LLAMA_FTYPE_MOSTLY_IQ3_S, " 3.44 bpw quantization", },
{ "IQ3_S_R4", LLAMA_FTYPE_MOSTLY_IQ3_S_R4, "IQ3_S repacked", },
@@ -63,6 +65,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ2_K", LLAMA_FTYPE_MOSTLY_IQ2_K, " 2.375 bpw non-linear quantization",},
{ "IQ2_K_R4", LLAMA_FTYPE_MOSTLY_IQ2_K_R4, "IQ2_K repacked",},
{ "IQ2_KS", LLAMA_FTYPE_MOSTLY_IQ2_KS, " 2.1875 bpw non-linear quantization",},
{ "IQ2_KT", LLAMA_FTYPE_MOSTLY_IQ2_KT, " 2.125 bpw trellis quantization", },
{ "IQ3_K", LLAMA_FTYPE_MOSTLY_IQ3_K, " 3.44 bpw non-linear quantization", },
{ "IQ3_K_R4", LLAMA_FTYPE_MOSTLY_IQ3_K_R4, "IQ3_K repacked", },
{ "IQ3_KL", LLAMA_FTYPE_MOSTLY_IQ3_KL, " 4 bpw non-linear quantization mix",},
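With these entries in QUANT_OPTIONS the new types become selectable by name in the usual quantize tool, e.g. llama-quantize model-f16.gguf model-iq2_kt.gguf IQ2_KT (the file names here are placeholders).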
6 changes: 6 additions & 0 deletions ggml/include/ggml.h
@@ -416,6 +416,9 @@ extern "C" {
GGML_TYPE_Q8_K32 = 148,
GGML_TYPE_Q8_KR8 = 149,
GGML_TYPE_Q8_K128 = 150,
GGML_TYPE_IQ2_KT = 151,
GGML_TYPE_IQ3_KT = 152,
GGML_TYPE_IQ4_KT = 153,

GGML_TYPE_Q4_0_R8 = 202,
GGML_TYPE_Q5_0_R4 = 206,
@@ -501,6 +504,9 @@ extern "C" {
GGML_FTYPE_MOSTLY_IQ4_KS = 137, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ2_KS = 138, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_KSS = 139, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ2_KT = 140, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ3_KT = 141, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ4_KT = 142, // except 1d tensors
//
GGML_FTYPE_MOSTLY_Q4_0_R8 = 202, // except 1d tensors
GGML_FTYPE_MOSTLY_Q8_0_R8 = 207, // except 1d tensors
18 changes: 18 additions & 0 deletions ggml/src/ggml-common.h
@@ -606,6 +606,24 @@ typedef struct {
} block_iq2_ks;
static_assert(sizeof(block_iq2_ks) == sizeof(uint16_t) + QK_K/64 + QK_K/4, "wrong iq2_ks block size/padding");

typedef struct {
uint8_t scales[QK_K/64];
uint8_t ql[QK_K/4];
} block_iq2_kt;
static_assert(sizeof(block_iq2_kt) == QK_K/4 + QK_K/64, "wrong iq2_kt block size/padding");

typedef struct {
uint8_t scales[QK_K/64];
uint8_t ql[QK_K/4];
uint8_t qh[QK_K/8];
} block_iq3_kt;
static_assert(sizeof(block_iq3_kt) == QK_K/4 + QK_K/8 + QK_K/64, "wrong iq3_kt block size/padding");

typedef struct {
uint32_t qs[QK_K/8];
} block_iq4_kt;
static_assert(sizeof(block_iq4_kt) == QK_K/2, "wrong iq4_kt block size/padding");

typedef struct {
ggml_half d;
uint16_t extra;
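A quick sanity check on the advertised bits per weight, derived from the block layouts above (a sketch, assuming QK_K = 256 as elsewhere in ggml; the per-row float header(s) used by these types add a small amount on top, which is why a later commit switches to reporting the actual bpw):

#include <cstdio>

int main() {
    constexpr int QK_K = 256;
    // bytes per block of QK_K weights, taken straight from the struct definitions above
    constexpr int iq2_kt_bytes = QK_K/4 + QK_K/64;           // ql + scales       = 68
    constexpr int iq3_kt_bytes = QK_K/4 + QK_K/8 + QK_K/64;  // ql + qh + scales  = 100
    constexpr int iq4_kt_bytes = QK_K/2;                     // qs                = 128
    printf("IQ2_KT: %.3f bpw\n", 8.0*iq2_kt_bytes/QK_K);     // 2.125
    printf("IQ3_KT: %.3f bpw\n", 8.0*iq3_kt_bytes/QK_K);     // 3.125
    printf("IQ4_KT: %.3f bpw\n", 8.0*iq4_kt_bytes/QK_K);     // 4.000
    return 0;
}

These match the 2.125, 3.125 and 4.0 bpw figures listed for IQ2_KT, IQ3_KT and IQ4_KT in quantize.cpp.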
4 changes: 4 additions & 0 deletions ggml/src/ggml-cuda.cu
@@ -1943,6 +1943,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1;
bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
&& ggml_cuda_mmvq_type_supported(src0->type)
&& src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
&& src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
bool use_mul_mat_q = ggml_is_quantized(src0->type)
@@ -2847,6 +2848,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_TYPE_IQ4_KSS:
case GGML_TYPE_IQ2_K:
case GGML_TYPE_IQ2_KS:
case GGML_TYPE_IQ2_KT:
case GGML_TYPE_IQ3_KT:
case GGML_TYPE_IQ4_KT:
case GGML_TYPE_IQ3_K:
case GGML_TYPE_IQ4_K:
case GGML_TYPE_IQ5_K:
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -508,6 +508,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KS> {
static constexpr int qi = QI4_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_KT> {
static constexpr int qk = QK_K;
static constexpr int qr = QR4_XS;
static constexpr int qi = QI4_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_K> {
static constexpr int qk = QK_K;
157 changes: 157 additions & 0 deletions ggml/src/ggml-cuda/convert.cu
@@ -333,6 +333,130 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}

inline __device__ int nearest_int(float fval) {
assert(fval <= 4194303.f);
float val = fval + 12582912.f;
int i; memcpy(&i, &val, sizeof(int));
return (i & 0x007fffff) - 0x00400000;
}

float __device__ __forceinline__ trellis_next(uint32_t& val) {
constexpr uint32_t ka = 89226354;
constexpr uint32_t kb = 64248484;
constexpr uint32_t kmask = 0x8fff8fff;
constexpr uint32_t km32 = 0x3b603b60;
uint32_t s;
const half * h = (const half *)&s;
val = ka*val + kb;
s = (val & kmask) ^ km32;
//float r = (float)(h[0] +h[1]);
//val = ka*val + kb;
//s = (val & kmask) ^ km32;
//r += (float)(h[0]+h[1]);
//return r;
return (float)(h[0]+h[1]);
}

template<typename dst_t>
static __global__ void dequantize_block_iq2_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {

int64_t ii = blockIdx.x;
int64_t row = (QK_K * ii) / n_per_row;
const char * cx = (const char *)vx + row * row_size;
float scale = *(const float *)cx;
const block_iq2_kt * x = (const block_iq2_kt *)(cx + sizeof(float));
const int64_t i = ii - (row*n_per_row)/QK_K;

const int64_t tid = threadIdx.x;
const int64_t ib = tid; // 0...31
dst_t * y = yy + ii*QK_K + 8*ib;
const uint16_t * ql = (const uint16_t *)x[i].ql;
uint32_t idx = ql[ib] + 4096;
const float dl = scale * iq4k_values[((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf)] * 31.75f * 1.05f;
for (int j = 0; j < 8; ++j) {
y[j] = dl * trellis_next(idx);
}
}

template<typename dst_t>
static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {

int64_t ii = blockIdx.x;
int64_t row = (QK_K * ii) / n_per_row;
const char * cx = (const char *)vx + row * row_size;
float scale = *(const float *)cx;
const block_iq3_kt * x = (const block_iq3_kt *)(cx + sizeof(float));
const int64_t i = ii - (row*n_per_row)/QK_K;

const int64_t tid = threadIdx.x;
const int64_t ib = tid; // 0...31
dst_t * y = yy + ii*QK_K + 8*ib;
const uint16_t * ql = (const uint16_t *)x[i].ql;
uint32_t idx = ql[ib] + 4096;
const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 31.75f * 1.01f; //1.015f;
uint8_t mask = 1 << (ib/4);
for (int j = 0; j < 8; ++j) {
y[j] = dl * std::abs(trellis_next(idx)) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f);
}
}

//template<typename dst_t>
//static __global__ void dequantize_block_iq3_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {
//
// int64_t ii = blockIdx.x;
// int64_t row = (QK_K * ii) / n_per_row;
// const float * dptr = (const float *)((const char *)vx + row * row_size);
// float scale = dptr[0];
// float alpha = dptr[1];
// const block_iq3_kt * x = (const block_iq3_kt *)(dptr + 2);
// const int64_t i = ii - (row*n_per_row)/QK_K;
//
// const int64_t tid = threadIdx.x;
// const int64_t ib = tid; // 0...31
// dst_t * y = yy + ii*QK_K + 8*ib;
// const uint16_t * ql = (const uint16_t *)x[i].ql;
// uint32_t idx = ql[ib] + 4096;
// const float dl = scale * ((x[i].scales[(ib/4)%4] >> 4*(ib/16)) & 0xf) * 31.75f * 1.01f; //1.015f;
// uint8_t mask = 1 << (ib/4);
// for (int j = 0; j < 8; ++j) {
// float ay = std::abs(trellis_next(idx));
// y[j] = dl * ay/(1 - alpha*ay) * (x[i].qh[(8*ib+j)%32] & mask ? -1.f : 1.f);
// }
//}

template<typename dst_t>
static __global__ void dequantize_block_iq4_kt(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t n_per_row, int64_t row_size) {

int64_t ii = blockIdx.x;
int64_t row = (QK_K * ii) / n_per_row;
const float * dptr = (const float *)((const char *)vx + row * row_size);
float scale = dptr[0] * 31.75f * 1.01f;
float row_av = dptr[1];
const block_iq4_kt * x = (const block_iq4_kt *)(dptr + 2);
const int64_t i = ii - (row*n_per_row)/QK_K;

constexpr int kNumGroups = 64;

const int64_t tid = threadIdx.x;
const int64_t ib = tid; // 0...31
dst_t * y = yy + ii*QK_K + 8*ib;
const uint32_t * shb = x[i].qs;
const uint8_t * ql = (const uint8_t *)(shb + 8); //Q::kNblock;
const uint8_t * qh = ql + kNumGroups;
const int ib32 = ib/4;
const int ig = ib%4;
const int jj = ib32*8 + 2*ig;
uint32_t offset = shb[ib32] & 1 ? 4096 + 32768 : 4096;
uint32_t idx1 = ql[jj+0] + ((qh[(jj+0)%(kNumGroups/2)] << (8 - 4*((jj+0)/(kNumGroups/2)))) & 0xf00) + (((shb[ib32] >> (8 + 6*ig+0)) & 7) << 12) + offset;
uint32_t idx2 = ql[jj+1] + ((qh[(jj+1)%(kNumGroups/2)] << (8 - 4*((jj+1)/(kNumGroups/2)))) & 0xf00) + (((shb[ib32] >> (8 + 6*ig+3)) & 7) << 12) + offset;
int ls = ((shb[ib32] & 0xff) >> 1) - 64;
const float dl = scale * ls;
for (int j = 0; j < 4; ++j) {
y[j+0] = dl * trellis_next(idx1) + row_av;
y[j+4] = dl * trellis_next(idx2) + row_av;
}
}

template<typename dst_t>
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

@@ -862,6 +986,27 @@ static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq2_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq2_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ2_KT, n_per_row));
}

template<typename dst_t>
static void dequantize_row_iq3_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq3_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ3_KT, n_per_row));
}

template<typename dst_t>
static void dequantize_row_iq4_kt_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
const int nb = k / QK_K;
dequantize_block_iq4_kt<<<nb, 32, 0, stream>>>(vx, y, n_per_row, ggml_row_size(GGML_TYPE_IQ4_KT, n_per_row));
}

template<typename dst_t>
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
const int64_t k = nrows * n_per_row;
@@ -1098,6 +1243,12 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_q6_K_cuda;
case GGML_TYPE_IQ2_XXS:
return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_IQ2_KT:
return dequantize_row_iq2_kt_cuda;
case GGML_TYPE_IQ3_KT:
return dequantize_row_iq3_kt_cuda;
case GGML_TYPE_IQ4_KT:
return dequantize_row_iq4_kt_cuda;
case GGML_TYPE_IQ2_XS:
return dequantize_row_iq2_xs_cuda;
case GGML_TYPE_IQ2_S:
@@ -1169,6 +1320,12 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_q6_K_cuda;
case GGML_TYPE_IQ2_XXS:
return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_IQ2_KT:
return dequantize_row_iq2_kt_cuda;
case GGML_TYPE_IQ3_KT:
return dequantize_row_iq3_kt_cuda;
case GGML_TYPE_IQ4_KT:
return dequantize_row_iq4_kt_cuda;
case GGML_TYPE_IQ2_XS:
return dequantize_row_iq2_xs_cuda;
case GGML_TYPE_IQ2_S:
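For reference, the trellis ("KT") types reconstruct weights on the fly instead of storing them: trellis_next() above advances a 32-bit linear congruential generator (val = ka*val + kb), masks and xors the state so both 16-bit halves land in a narrow fp16 range, and returns the sum of the two half values. In iq2_kt and iq3_kt each group of 8 weights is produced by seeding the generator with a stored 16-bit index plus 4096 and drawing 8 values; iq4_kt assembles a wider index from several bit fields. The drawn values are then multiplied by the block and row scales. A minimal host-side sketch of the generator, assuming a little-endian machine and summing in float rather than half (so the last bit may differ from the CUDA kernels):

// Host-side sketch of the trellis value generator used by the IQ*_KT kernels above.
// This mirrors trellis_next() in ggml-cuda/convert.cu; fp16 decoding is done by hand
// so the snippet compiles without CUDA headers.
#include <cstdint>
#include <cstdio>
#include <cmath>

// Decode an IEEE 754 binary16 value from a 16-bit pattern.
static float fp16_to_float(uint16_t h) {
    uint32_t sign = (h >> 15) & 1;
    uint32_t exp  = (h >> 10) & 0x1f;
    uint32_t mant =  h        & 0x3ff;
    float f;
    if (exp == 0)       f = std::ldexp((float)mant, -24);           // subnormal
    else if (exp == 31) f = mant ? NAN : INFINITY;                  // inf / nan
    else                f = std::ldexp((float)(mant | 0x400), (int)exp - 25);
    return sign ? -f : f;
}

// One generator step: LCG update, mask/xor, then sum the two half values.
static float trellis_next_host(uint32_t & val) {
    constexpr uint32_t ka    = 89226354;
    constexpr uint32_t kb    = 64248484;
    constexpr uint32_t kmask = 0x8fff8fff;
    constexpr uint32_t km32  = 0x3b603b60;
    val = ka*val + kb;
    const uint32_t s = (val & kmask) ^ km32;
    return fp16_to_float((uint16_t)(s & 0xffff)) + fp16_to_float((uint16_t)(s >> 16));
}

int main() {
    // As in the iq2_kt/iq3_kt kernels: seed with a packed 16-bit index + 4096,
    // then draw 8 values for one group of 8 weights. The index here is hypothetical.
    uint32_t idx = 1234 + 4096;
    for (int j = 0; j < 8; ++j) printf("%g ", trellis_next_host(idx));
    printf("\n");
    return 0;
}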