Trellis quantization #113

Draft: ikawrakow wants to merge 55 commits into main

Conversation

@ikawrakow (Owner) commented Nov 15, 2024

The latest quantization hype is QTIP - paper, repository. They use a Trellis approach and report impressive results, so I decided to look into this more closely.

This PR implements what they call "3INST" in their paper. Basically, starting from an integer seed (the variable seed below), we generate N quantized values q[i] via

uint32_t u32;
float16_t * h = reinterpret_cast<float16_t *>(&u32);  // view u32 as two fp16 values
for (int i = 0; i < N; ++i) {
    seed = a * seed + b;            // LCG update of the seed
    u32  = (mask1 & seed) ^ mask2;  // mask/flip bits so both halves are valid fp16 values
    q[i] = h[0] + h[1];             // sum of the two fp16 halves is (nearly) Gaussian
}

where a, b, mask1 and mask2 are suitable constants. This generates values that are (nearly) normally distributed. One uses this to describe a group of N quants with a single L-bit seed (index). Apart from borrowing the "3INST" algorithm from the QTIP paper, the implementation here has nothing else in common with QTIP - there are no Hadamard transforms, and no (tail-biting) Viterbi algorithm is utilized during quantization. Instead, in the usual i- and k-quants style, quants are organized in blocks and super-blocks with suitable block scales, and the search for the best seed during quantization is done via a clustering algorithm.
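
For reference, here is a self-contained sketch of such a generator. The constants below are the 3INST values as I understand them from the QTIP code; treat them, and the helper names, as illustrative assumptions rather than the exact values and layout used in this PR.

#include <cstdint>
#include <cstring>
#include <cmath>

// Assumed 3INST constants (illustrative; the actual implementation may differ).
static constexpr uint32_t ka    = 89226354u;   // LCG multiplier
static constexpr uint32_t kb    = 64248484u;   // LCG increment
static constexpr uint32_t kmask = 0x8fff8fffu; // keeps sign + low bits of each fp16 half
static constexpr uint32_t kxor  = 0x3b603b60u; // fixes the exponent bits of each half

// Minimal fp16 -> float conversion (normal numbers only, which is all the
// masked/xored bit patterns above can produce).
static float fp16_to_float(uint16_t h) {
    const uint32_t sign = h >> 15, exp = (h >> 10) & 0x1f, mant = h & 0x3ff;
    if (exp == 0) return 0.0f;
    const float v = std::ldexp(1.0f + mant/1024.0f, int(exp) - 15);
    return sign ? -v : v;
}

// Expand one seed (the stored L-bit index) into n (nearly) Gaussian values.
static void trellis_gen(uint32_t seed, int n, float * out) {
    for (int i = 0; i < n; ++i) {
        seed = ka*seed + kb;                      // LCG step
        const uint32_t u32 = (seed & kmask) ^ kxor;
        uint16_t h[2];
        std::memcpy(h, &u32, sizeof h);           // the two fp16 halves of u32
        out[i] = fp16_to_float(h[0]) + fp16_to_float(h[1]);
    }
}

Each stored index thus costs L bits for a group of N values; the block scales then map these roughly unit-variance values back to the actual weight magnitudes.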

The PR adds 3 new quantization types:

  • IQ2_KT: L=16 bits for groups of N=8 quants. Block size is 32 with a 4-bit block scale, plus a single float scale per tensor row (the 32 bits added by this scale can be safely neglected for typical tensor row sizes), so we end up using 2.125 bpw (the arithmetic is spelled out below the list)
  • IQ3_KT: L=12 bits for groups of N=4 quants. Block size is also 32 with a 4-bit block scale, so 3.125 bpw
  • IQ4_KT: L=15 bits for groups of N=4 quants. Blocks of 32 with 8-bit block scales, so 4.0 bpw.
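
To make the bit counting explicit: for IQ2_KT each group of 8 weights costs a 16-bit index (16/8 = 2.0 bpw) and each block of 32 weights adds a 4-bit scale (4/32 = 0.125 bpw), giving 2.125 bpw; for IQ4_KT it is 15/4 = 3.75 bpw for the indices plus 8/32 = 0.25 bpw for the block scales, i.e. 4.0 bpw.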

Quantization accuracy

This figure shows quantization error PPL(Q)/PPL(bf16)-1 for LLaMA-3.1-8B-Instruct (context length of 8192 tokens). The blue symbols are k-quants, the black symbols are i-quants, cyan symbols are iqk-quants (not available in mainline llama.cpp), and the orange symbols are the Trellis quants added by this PR. We do see a small but noticeable improvement compared to i- and iqk-quants, with about 0.2 fewer bpw required to achieve the same quantization error.

[Figure: quantization error PPL(Q)/PPL(bf16)-1 vs. bpw for LLaMA-3.1-8B-Instruct]

How does this compare to the QTIP paper? Unfortunately, they only report results without fine-tuning for LLaMA-v2. The table below compares the 2-bit quantizations for LLaMA-v2-7B (the QTIP results are taken from Table 3 in their paper; context length is 4096 tokens).

Quantization    PPL(f16)    PPL(Q)    Quantization error
QTIP 2 bpw      5.12        6.82      33.2%
IQ2_KT          4.94        6.36      28.7%
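
(The quantization error column is just PPL(Q)/PPL(f16) - 1; e.g. for IQ2_KT, 6.36/4.94 - 1 ≈ 0.287, i.e. 28.7%.)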

Although there are small differences between the PPL computed by llama.cpp and by the tools used by the QTIP authors, the quantization error as defined above is basically independent of the specifics of the PPL calculation, so we see that the 2 bpw quantization implemented here slightly outperforms QTIP without fine-tuning (at the expense of using 0.125 bpw more). Given this, and the above graph, my conclusion is that Trellis-based quantization is a small improvement compared to i-, k-, and iqk-quants, but nowhere near the hype observed around the Internet.

Performance

The QTIP authors give TG speed for their 2 bpw variant on an RTX-6000 Ada GPU (see here) and a 7B LLaMA model. My GPU is an RTX-4080 (same generation as theirs, but lower specs). I made a quick attempt to get QTIP going in my environment to have an apples-to-apples performance comparison, but it was not successful, so I will use the ratio of their f16 performance on the RTX-6000 (55.9 t/s) to my fp16 performance on the RTX-4080 (46.2 t/s) to translate QTIP performance on the RTX-6000 (188 t/s) into estimated performance on the RTX-4080:

QTIP (2 bpw, RTX-4080) = fp16(RTX-4080)/fp16(RTX-6000) * QTIP (2 bpw, RTX-6000) = 46.2/55.9*188 = 155.4 t/s

In comparison, I get 194 t/s for IQ2_KT (with flash attention enabled, which I assume they also use). These results are with the output tensor left as f16 (which is what is done in QTIP). IQ2_XXS achieves 208 t/s (output as f16) or 216 t/s (output as Q5_K), so QTIP performance is far behind that of a model of similar size using a more efficient quantization.

Caveats

  • Quantization is only implemented for a CPU with AVX2 support. The search for the optimum seed is extremely expensive (the QTIP authors say "prohibitive" for L >= 12 without their tail-biting search space reduction), so I had to SIMDify to not have to wait forever for a quantization to finish. This PR being mostly a POC for now, I did not want to spend the time implementing for other instruction sets (or even porting to run on a GPU).
  • Even with AVX2, quantization is slow - depending on quantization type it takes between 2.5 and 4.5 minutes to quantize LLaMA-3.1-8B on a 32-core Ryzen-5975WX CPU.
  • Inference is only implemented on CUDA. Due to the "3INST" algorithm, I expect low performance on the CPU and on the Apple GPU, so did not bother to implement for those.
  • There are no quantized matrix-vector kernels, so implementation is via the DMMV mechanism in llama.cpp. The algorithm outputs float values, so one needs to convert to int8_t to use the usual quantized dot products. The cost of this conversion is likely to (more than) offset any advantage one might gain by using SIMD int8_t dot products.
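
To illustrate the last point, here is a rough sketch of what a DMMV-style row dot product looks like with this kind of quant. The data layout (one 16-bit index per group of 8, one 4-bit scale per block of 32, one float scale per row) follows the IQ2_KT description above, but the exact packing is an assumption, and trellis_gen is the generator sketch from earlier; this is not the actual CUDA kernel.

// Hypothetical layout: per row we have a float scale d_row, one 4-bit scale per
// block of 32 weights (stored unpacked here for clarity), and one 16-bit trellis
// index per group of 8 weights.
static float trellis_dot_row(const uint16_t * idx, const uint8_t * scales,
                             float d_row, const float * x, int n_per_row) {
    float sum = 0.0f;
    for (int ib = 0; 32*ib < n_per_row; ++ib) {
        const float d = d_row * scales[ib];             // block scale (assumed unsigned 4-bit)
        for (int ig = 0; ig < 4; ++ig) {                // 4 groups of 8 per block of 32
            float q[8];
            trellis_gen(idx[4*ib + ig], 8, q);          // expand the 16-bit index to 8 floats
            for (int j = 0; j < 8; ++j)
                sum += d * q[j] * x[32*ib + 8*ig + j];  // float multiply-add, no int8 path
        }
    }
    return sum;
}

Every weight has to be regenerated through the LCG and fp16 summation before it can be multiplied, which is why a quantized int8_t dot product (MMVQ style) does not fit naturally and why CPU performance is expected to suffer.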

ikawrakow marked this pull request as draft on November 15, 2024, 19:05
Using 12 bits per 8 weights I get a better rmse than
iq2_xxs. I still need to see how quantizing the group-of-8
scales will affect accuracy. By AVX2 SIMDifying the search
for the best code, LLaMA-3.1-8B gets quantized in 130 seconds
on the Ryzen-7950X CPU - sluggish but still acceptable.
rmse increases by just 3%, so this is beating iq2_xxs in terms
of rmse at the same 2.0625 bpw.
I now see that I was comparing apples to oranges:
iq2_xxs was using a weight of sigma^2/4 + x^2, while
the Trellis approach wasn't (weight = 1). Once I use the same weight,
iq2_kt is actually slightly worse than iq2_xxs in terms
of rmse, so does not look promising at this point.
Also, once each group of 8 Trellis values no longer has a
constant sum(q^2) that we can precompute, quantization
becomes significantly slower (476 seconds for LLaMA-3.1-8B).
so we can run perplexity calcs.
As already indicated by rmse, the 2-bit trellis approach is
quite a bit worse than iq2_xxs.
With blocks of 32 and 16 bits per group of 8 the brute force
search becomes prohibitive in terms of CPU time (30+ minutes
for 8B LLaMA after SIMDifying with AVX2). The trick is to
group the points in clusters, find the nearest cluster,
and only search within the cluster.
Using blocks of 32 and 16 bits per group of 8 weights
it beats iq2_xxs in terms of PPL by a significant margin.
It is 0.0625 bpw larger, but even if we go to 15 bits per
group of 8 (so 0.0625 bpw less than iq2_xxs), PPL is still
lower.
Re-quantize after determining block scales
(at the expense of much longer quantization time).
Implemented as DMMV.
Very slow - just 81 t/s for LLaMA-3.1-8B.
Then again, Q2_K_S when forced to use DMMV only
gets 112 t/s vs 145 t/s via MMVQ. My memory is that
when the DMMV kernels were properly maintained/used,
DMMV was about on par with MMVQ for k-quants on my GPU.
We arrive at 112 t/s.
We arrive at 139 t/s (no FA), and 149 t/s (FA).

My RTX-4080 is ~20% slower than the RTX-6000 quoted in the
QTIP repository, so with FA (which I'm sure they also used)
we are at around ~180 t/s on their GPU, so almost matching
their performance.
We arrive at 146 t/s (no FA), and 158 t/s (FA).
This is measured for LLaMA-3.1-8B with output.weight
left as f16.
3.125 bpw. So far does not look good on the PPL vs bpw plot.
PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.8322, which is
starting to be competitive/slightly better than other quants.
PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7892
PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7689 after shrinking
by 0.015 bpw by using iq4_k instead of q5_k for attn_v.
Nearly 60% improvement in quantization speed by having the
points belonging to a cluster copied to contiguous memory
during initialization, and then accessed sequentially while
searching for the closest point. LLaMA-3.1-8B now gets
quantized in ~150 seconds on the Ryzen-5975WX.
Same trick as last commit applied to iq2_kt. Here we get
an even larger speedup: quantization time on the Ryzen-5975WX
for LLaMA-3.1-8B drops to 195 seconds from 375 seconds!
Go to groups of 8 for iq3_kt. 2 x 8 = 16 bits for the magnitude
plus 1 bpw for the sign. It gives a visible improvement in the
PPL vs bpw plot, but that comes at the expense of much longer
quantization time (7.5 minutes for LLaMA-3.1-8B on the Ryzen-5975WX).

I also noticed that the 3INST generator is not actually generating a
Gaussian distribution. But going to a better generator means
readjusting all the hyper-parameters, so leaving it for later.
@saood06 (Collaborator) commented Apr 7, 2025

Turboderp was also inspired by QTIP when redoing quantization for their new inference engine found here.

There are graphs and more details showing the performance of their quants here.

I'm interested and will look into it (maybe once the inference engine matures a bit), as I haven't tested using just my 3090 for a 70B model in a long while (the few recent times I wanted to use a 70B, I used quants that are too big to fit on my 3090 and thus need to be only partially offloaded).

@compilade

There are graphs and more details showing the performance of their quants here.

Note that they did not quantize the embeddings with EXL3, while they might have with GGUF (not sure, still needs verification), and this might affect the perplexity graphs since they did not include the size of that tensor in the graphs.

(But since they also untie tied embeddings (to quantize the output tensor), it might be hard to compare fairly depending on the model architecture)

Still looks very promising, though!

@saood06 (Collaborator) commented Apr 7, 2025

There are graphs and more details showing the performance of their quants here.

Note that they did not quantize the embeddings with EXL3, while they might have with GGUF (not sure, still needs verification), and this might affect the perplexity graphs since they did not include the size of that tensor in the graphs.

(But since they also untie tied embeddings (to quantize the output tensor), it might be hard to compare fairly depending on the model architecture)

Still looks very promising, though!

The linked doc page says "Accounting for quantization of the output layer can make a huge difference in practice, especially for smaller models. So I am including two versions of each perplexity graph, one with bitrate on the horizontal axis, and one that measures the entire VRAM footprint of the weights (not counting the embedding layer which for most inference tasks can be relegated to system RAM.)"

So the bpw chart includes the embeddings layer, it seems, and the VRAM one does not (both of which are useful, so I'm glad they offered both).

Still looks very promising, though!

Yes.

@ikawrakow (Owner, Author)

I don't like these plots too much. The y-axis needs to be logarithmic, and it needs to be the difference to unquantized, not absolute values (else we are chasing differences between possibly different ways of computing perplexity). Also, they massively overemphasize the low bpw range. If you plot on a log scale, you get a more realistic picture. Either way, yes, trellis quantization can bring a 0.1-0.2 bpw reduction in quantized size for the same model quality. But is there any indication of performance? I could get my implementation here to be reasonably performant on CUDA, but expect the CPU implementation to be a disaster performance wise.

@saood06 (Collaborator) commented Apr 7, 2025

I don't like these plots too much. The y-axis needs to be logarithmic, and it needs to be the difference to unquantized, not absolute values (else we are chasing differences between possibly different ways of computing perplexity). Also, they massively overemphasize the low bpw range. If you plot on a log scale, you get a more realistic picture.

Yes but they are good enough for just looking at a VRAM amount and seeing the expected quality for it with the different quants.

Either way, yes, trellis quantization can bring a 0.1-0.2 bpw reduction in quantized size for the same model quality.

The gain is larger going from exllamaV2 to V3, since EXL2 was much worse at low bpw than i-quants. (People did say it offered a better KV cache than llama.cpp due to the Hadamard transform added here, even if the model quantization was not as good.)

Even though the performance on ik_llama.cpp is lower for CUDA I still prefer it to exllamaV2 because of iqk quants (and also the side benefit of one API implementation) when running models that fit solely on my 3090.

But is there any indication of performance? I could get my implementation here to be reasonably performant on CUDA, but expect the CPU implementation to be a disaster performance wise.

Exllama is designed for GPUs (and right now only CUDA with ROCm planned) and they are previewing this alongside a new version of their inference software.

The Readme says,

"Aside from lifting a few of the most successful features from V2 (such as the generator), ExLlamaV3 is largely rewritten from scratch to provide a cleaner, more modular framework for supporting newer architectures. It also introduces a new SOTA quantization format based on QTIP"

"The framework is not yet fully optimized. Performance is lacking, especially on Ampere [...]"

but expect the CPU implementation to be a disaster performance wise.

That is unfortunate.

@ikawrakow (Owner, Author)

People did say it offered a better KV cache than llama.cpp due to the Hadamard transform added here, even if the model quantization was not as good

This is interesting. I have tried Hadamard transforms for model weight quantization because of the claims in the QuIP papers, but I never saw any improvement from it. I haven't tried for KV cache quantization, though.

@saood06 (Collaborator) commented Apr 8, 2025

Also I forgot to mention it but I did mention your PR to the QTIP authors shortly after you made this draft PR. They said "It seems like they didn't bother making the weights Gaussian first (the IP part of QTIP) before quantizing with a Gaussian codebook (3INST)."

You say in the PR "This generates values that are (nearly) normally distributed." and in a commit message "I also noticed that the 3INST generator is not actually generating a Gaussian distribution." Do you think that following the authors' suggestion would result in a meaningful difference in quality, or is that something you would expect not to matter as much? (I'm not asking you to implement it if you don't know; I know this PR took a long time, and the fact that it is not CPU friendly means it has limited utility for this repo.)

@ikawrakow (Owner, Author)

It depends on what the QTIP authors mean by "they didn't bother making the weights Gaussian first". If they mean that I did not apply a Hadamard transform first, I did try that (QuIP/QuIP#/QTIP all insist on applying Hadamard transforms to model weights before quantization), but it did not improve the result in any way. The thing about Hadamard transforms and imatrix is that they do not mix well - one needs a special imatrix for that. But I have also tried this, without much success. If they mean that I have missed something in the 3INST implementation, and hence the generated sequence is not normally distributed, and it would be better otherwise, I cannot confirm that either. I did a lot of Monte Carlo stuff in the past, so I know a thing or two about random number sequences. I tried an implementation that produces a perfect Gaussian distribution (and quite a bit more efficiently than theirs), but that made results worse.
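
(For reference, a minimal sketch of one textbook way to get an exactly Gaussian sequence out of the same kind of LCG is Box-Muller; this is only an illustration, not necessarily the variant that was tried here:)

#include <cstdint>
#include <cmath>

// Turn an LCG stream into exactly normally distributed values via Box-Muller.
// ka/kb are the same assumed LCG constants as in the generator sketch above.
static void gaussian_gen(uint32_t seed, int n, float * out) {
    constexpr uint32_t ka = 89226354u, kb = 64248484u;
    for (int i = 0; i < n; i += 2) {
        seed = ka*seed + kb; const float u1 = ((seed >> 8) + 1) * (1.0f/16777216.0f); // uniform in (0,1]
        seed = ka*seed + kb; const float u2 =  (seed >> 8)      * (1.0f/16777216.0f); // uniform in [0,1)
        const float r = std::sqrt(-2.0f*std::log(u1)), t = 6.2831853f*u2;
        out[i] = r*std::cos(t);
        if (i + 1 < n) out[i+1] = r*std::sin(t);
    }
}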

I was planning to try a sequence that generates quantized values, so CPU inference will be more efficient. But then I started doing other stuff, so that never materialized.

But do the QTIP authors believe theirs is much better than what I have done? My impression was that it was about the same, give or take.

@saood06 (Collaborator) commented Apr 8, 2025

I was planning to try a sequence that generates quantized values, so CPU inference will be more efficient. But then I started doing other stuff, so that never materialized.

That sounds interesting.

It depends on what the QTIP authors mean by ...
...
But do the QTIP authors believe theirs is much better than what I have done? My impression was that it was about the same, give or take.

I don't know, the one line I quoted ("It seems ...") is the only thing they said to me. I was merely asking out of my own curiosity, I have no intention of testing their implementation but I may end up testing the EXL3 implementation once it has matured.

@louiehelm

The Hadamard Bros and other people fixated on rotations aren't doing it primarily to improve LLM weight quantization. It's for eliminating downstream outliers in run-time activations + KV-cache so they can successfully quantize those more aggressively down to 4-bits without scrambling model fidelity.

Activations and KV-cache are only more sensitive to quantization because of 5-10 tokens per model that represent attention sinks (like [BOS] or "\n"), which typically have activation values >100,000x larger than all the other tokens. This is why, even though 4-bit activations only cause ~0.0001% average error, it still breaks most models: the error is all concentrated in these 5-10 essential tokens. This can cause models to glitch out or loop when they're over-quantized. Activation values for attention sinks (outlier tokens) end up very finely calibrated during training, so most models immediately become flaky when they're perturbed.

There's another way to resolve this besides submitting to the Hadamard cult. PrefixQuant is a fairly small patch to KV-cache and activation handling that marks the 5-10 largest outlier tokens and just always pre-caches them into KV-cache in full f32. Then 4-bit quantize all the other activations and kv-cache for huge speed and memory benefits and no quality trade-off.

@saood06 (Collaborator) commented Apr 18, 2025

There's another way to resolve this besides submitting to the Hadamard cult.

The author of ExllamaV3 reported that they will attempt other ideas as well and only go back to Hadamard if they don't work better.

@saood06 (Collaborator) commented Apr 19, 2025

PrefixQuant

Finally got a chance to read the paper.

is a fairly small patch

Look at "Table 5: Ablation study on quantization techniques used in PrefixQuant" and "Appendix D. More Ablation Results": the blockwise finetune took 17 hours on Llama-3-70B with an NVIDIA-A100-80GB GPU, and needing the correct dataset and exact training parameters also contributed to their results.

KV-cache and activation handling that marks the 5-10 largest outlier tokens and just always pre-caches them into KV-cache in full f32.

This still sounds useful; they reported this took 13 minutes on Llama-3-70B with an NVIDIA-A100-80GB GPU.

"Appendix H. More Visualizations" was really interesting to me. Thanks for the paper link.

@louiehelm commented Apr 22, 2025

It's fascinating how well your quants track optimal limits from rate-distortion theory.

Optimal D(R) = 2^(-2*bitrate)

[Figure: the quantization-error plot above with the optimal D(R) = 2^(-2*bitrate) bound overlaid]

Some of your new quants actually dip down to only ~1.25 bits of overhead.

That's really good considering "optimal" = infinite codebook (which prob hurts t/s)

@ikawrakow (Owner, Author)

Where does the equation for the optimal R(D) come from?

LLaMA-3 requires ~1 bpw more to achieve the same quantization error compared to other models (see #8). Does this mean that the coding overhead there is < 0.5 bpw? Or does it rather mean that the model weights in LLaMA-3 do contain more information (which is my interpretation)?

@louiehelm commented Apr 23, 2025

Worst-case model weights can be approximated as maximally unpredictable Gaussian data -- essentially what LLMs might become in the limit once they're trained hard enough to reach 100% entropy levels (a full 8.0 bits per byte).

Shannon's rate-distortion function:
  R(D) = ½ log₂(σ² / D)
Normalize weights with σ² = 1
  R(D) = ½ log₂(1 / D)
Solving for D as function of rate gives:
  D(R) = 2^(‑2R).
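
(Plugging in numbers: at R = 2 bpw this gives D = 2^(-4) = 0.0625, i.e. the best achievable mean squared error for unit-variance data at 2 bpw is 1/16 of the variance.)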

This is a foundational result from information theory. I applied it in this context because you really are making a code to preserve information as well as possible in lower bitrates. This concept is usually deployed to analyze lossy image or audio compression or to design superior analog channel protocols for multiplexing light better in fiber optic cables. But it applies equally in this setting too and it makes sense in retrospect that this limit would bound the maximum efficiency of any LLM quantization algorithms too.

The reason this is so interesting is because usually we use information theory to discuss way more boring distortion proxies like MSE (mean squared error). For an LLM, the MSE = Σ (original value - quantized value)² / (# of parameters). Have you ever investigated what this is for your quants? In any case, I just think it's beautiful seeing LLMs manifest such a clean offset from the optimal-distortion bound on such an abstract ability as being able to faithfully recite Wikipedia.

Also, if I re-plot my prior graph to use the optimal distortion bound as the x-axis and scale the y-axis to represent bits, even other quantization methods seem to pretty cleanly establish relatively consistent gaps off the distortion lower bound, with only minor deviations. [Note: There's likely bit-width overhead from non-weight params in EXL3 that I can't account for, so this chart may be ~5% more favorable than reality.]

[Figure: bpw overhead relative to the rate-distortion lower bound for various quantization methods]

And yes, your coding overhead for Llama-2 was remarkably small and very close to the limit. I grabbed Llama 3 70b and Llama 2 70b to check the entropy in the actual files. It only went up from 6.1 bits per byte to 6.27 bits per byte. Obviously L3 has more complexity packed into its weights, but in information-theoretic terms there doesn't appear to be a full +1.1 bits per parameter. Maybe that accounts for 30% of the gap? The other 70% may be from Meta engineers just using the full dynamic range of their weights better in L3 vs L2. This undoubtedly made training easier for them by making it more stable, but could have had the downstream effect of making the weights harder to quantize for your algorithm, which may have been tuned to expect numeric distributions more similar to L2 weights. Does your new Trellis quant also have a +1.1bit gap between L2 70b and L3 70b?

@saood06 (Collaborator) commented Apr 24, 2025

essentially what LLMs might become in the limit once they're trained hard enough to reach 100% entropy levels (a full 8.0 bits per byte)

Only some recent models are trained at FP8 (such as Deepseek V3/R1); they tend to be BF16, with FP4 training currently in the research stages (see this).

@saood06 (Collaborator) commented Apr 24, 2025

Exllama-V3 added cache quantization,

turboderp-org/exllamav3@cf84811

They also explain their reasoning in an issue copied below:

So cache quantization is implemented now. It's a variant of the same technique used in V2, but now with separate bitrates (2-8 bpw plus 0.5 bpw of overhead) for K and V channels. Works a little better than in V2, and it's more flexible.

I experimented with realtime trellis quantization, learned channel scales, autoencoders and more, but so far with little success, and not enough benefit to justify the overhead and complexity. There's still much to explore, though. For instance, I think it should be possible to learn an optimal rotation for the keys in a given layer, under a quantization constraint, then bake the same transformation into the Q and K projections, preserving their dot product.

But for the time being, it's too much of a side quest, and I need to focus on some other stuff first. In the meantime you can get very usable results from k4v3 quantization, and more-or-less lossless quantization with k5v4. And it's "usable" down to k3v2, depending on the use case. Might make the model more creative or something, who knows (:. I still have to rig up some tests to see if it holds up over long contexts.

@ikawrakow (Owner, Author)

Does your new Trellis quant also have a +1.1bit gap between L2 70b and L3 70b?

I have not tried it for 70B models. It is too slow for the amount of patience I have. I know some people are OK spending 2 days quantizing a model on a GPU, but I'm not one of those.

@ikawrakow (Owner, Author)

Worst-case model weights can be approximated as maximally unpredictable Gaussian data -- essentially what LLMs might become in the limit once they're trained hard enough to reach 100% entropy levels

I'm not sure I can follow. In my book, LLMs only work because there are patterns encoded in the model weights, i.e., the model weights of an LLM are pretty much the opposite of the memoryless signal required for these equations to hold. We also know that the model weights are definitely not Gaussian, and the so-called "outliers" (i.e., weights that do not fall within the expectation of a normal distribution) are more important than the others. Also, the rate-distortion equation tells us something about the difference between the signal (model weights) and its approximate representation (quantized model weights), but it tells us nothing about how this will affect observations (predicted token probabilities), which are the result of a complex set of linear and non-linear operations on the signal.

@saood06 (Collaborator) commented Apr 28, 2025

The Hadamard Bros and other people fixated on rotations aren't doing it primarily to improve LLM weight quantization. It's for eliminating downstream outliers in run-time activations + KV-cache so they can successfully quantize those more aggressively down to 4-bits without scrambling model fidelity.

The latest paper by the bitnet people is literally that: https://arxiv.org/abs/2504.18415
