Trellis quantization #113
Conversation
Using 12 bits per 8 weights I get a better rmse than iq2_xxs. I still need to see how quantizing the group-of-8 scales will affect accuracy. By AVX2 SIMDifying the search for the best code, LLaMA-3.1-8B gets quantized in 130 seconds on the Ryzen-7950X CPU - sluggish but still acceptable.
rmse increases by just 3%, so this is beating iq2_xxs in terms of rmse at the same 2.0625 bpw.
I now see that I was comparing apples to oranges: iq2_xxs was using a weight of sigma^2/4 + x^2, while the Trellis approach wasn't (weight = 1). Once I use the same weight, iq2_kt is actually slightly worse than iq2_xxs in terms of rmse, so does not look promising at this point. Also, once each group of 8 Trellis values no longer has a constant sum(q^2) that we can precompute, quantization becomes significantly slower (476 seconds for LLaMA-3.1-8B).
so we can run perplexity calcs. As already indicated by rmse, the 2-bit trellis approach is quite a bit worse than iq2_xxs.
With blocks of 32 and 16 bits per group of 8, the brute force search becomes prohibitive in terms of CPU time (30+ minutes for 8B LLaMA after SIMDifying with AVX2). The trick is to group the points in clusters, find the nearest cluster, and only search within the cluster.
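A minimal sketch of that two-stage search (assuming groups of 8 values and 16-bit seeds; names and data layout are illustrative, not the actual implementation):

```c++
#include <array>
#include <cstdint>
#include <limits>
#include <vector>

// Squared Euclidean distance between two groups of 8 values.
static float dist2(const float* a, const float* b) {
    float d = 0;
    for (int i = 0; i < 8; ++i) { float t = a[i] - b[i]; d += t * t; }
    return d;
}

struct CodeClusters {
    std::vector<std::array<float, 8>> centroid;                    // one centroid per cluster
    std::vector<std::vector<uint16_t>> member_seeds;               // seeds assigned to each cluster
    std::vector<std::vector<std::array<float, 8>>> member_points;  // decoded values of those seeds
};

// Find the best seed for a group of 8 weights x: stage 1 picks the nearest
// cluster centroid, stage 2 scans only the codes belonging to that cluster
// instead of all 2^16 candidates.
uint16_t find_best_seed(const float* x, const CodeClusters& cc) {
    size_t best_k = 0;
    float best_d = std::numeric_limits<float>::max();
    for (size_t k = 0; k < cc.centroid.size(); ++k) {
        float d = dist2(x, cc.centroid[k].data());
        if (d < best_d) { best_d = d; best_k = k; }
    }
    uint16_t best_seed = cc.member_seeds[best_k].front();
    best_d = std::numeric_limits<float>::max();
    for (size_t j = 0; j < cc.member_points[best_k].size(); ++j) {
        float d = dist2(x, cc.member_points[best_k][j].data());
        if (d < best_d) { best_d = d; best_seed = cc.member_seeds[best_k][j]; }
    }
    return best_seed;
}
```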
Using blocks of 32 and 16 bits per group of 8 weights it beats iq2_xxs in terms of PPL by a significant margin. It is 0.0625 bpw larger, but even if we go to 15 bits per group of 8 (so 0.0625 bpw less than iq2_xxs), PPL is still lower.
Re-quantize after determining block scales (at the expense of much longer quantization time).
Implemented as DMMV. Very slow - just 81 t/s for LLaMA-3.1-8B. Then again, Q2_K_S forced to use DMMV gets only 112 t/s vs 145 t/s via MMVQ. My memory is that when the DMMV kernels were properly maintained/used, DMMV was about on par with MMVQ for k-quants on my GPU.
We arrive at 112 t/s.
We arrive at 139 t/s (no FA), and 149 t/s (FA). My RTX-4080 is ~20% slower than the RTX-6000 quoted in the QTIP repository, so with FA (which I'm sure they also used) we would be at around 180 t/s on their GPU, almost matching their performance.
We arrive at 146 t/s (no FA), and 158 t/s (FA). This is measured for LLaMA-3.1-8B with output.weight left as f16.
3.125 bpw. So far does not look good on the PPL vs bpw plot.
PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.8322, which is starting to be competitive/slightly better than other quants.
PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7892
PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7689 after shrinking by 0.015 bpw by using iq4_k instead of q5_k for attn_v.
Nearly 60% improvement in quantization speed by having the points belonging to a cluster copied to contiguous memory during initialization, and then accessed sequentially while searching for the closest point. LLaMA-3.1-8B now gets quantized in ~150 seconds on the Ryzen-5975WX.
Same trick as last commit applied to iq2_kt. Here we get an even larger speedup: quantization time on the Ryzen-5975WX for LLaMA-3.1-8B drops to 195 seconds from 375 seconds!
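A sketch of what that contiguous layout can look like (illustrative names; compare with the nested vectors in the earlier sketch): all decoded points of a cluster are stored back-to-back, so the inner search loop streams through memory sequentially, which is what makes it cache- and SIMD-friendly.

```c++
#include <cstdint>
#include <limits>
#include <vector>

struct FlatClusters {
    std::vector<float>    points;  // 8 floats per candidate, cluster after cluster
    std::vector<uint16_t> seeds;   // seed of each stored candidate, same order
    std::vector<int>      begin;   // begin[k] .. begin[k+1]: candidate range of cluster k
};

// Scan only cluster k; the candidate points are contiguous in memory.
inline uint16_t search_cluster(const FlatClusters& fc, int k, const float* x) {
    int best = fc.begin[k];
    float best_d = std::numeric_limits<float>::max();
    for (int j = fc.begin[k]; j < fc.begin[k + 1]; ++j) {
        const float* q = fc.points.data() + 8 * j;
        float d = 0;
        for (int i = 0; i < 8; ++i) { float t = x[i] - q[i]; d += t * t; }
        if (d < best_d) { best_d = d; best = j; }
    }
    return fc.seeds[best];
}
```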
Go to groups of 8 for iq3_kt: 2 x 8 = 16 bits for the magnitude plus 1 bpw for the sign. It gives a visible improvement in the PPL vs bpw plot, but that comes at the expense of much longer quantization time (7.5 minutes for LLaMA-3.1-8B on the Ryzen-5975WX). I also noticed that the 3INST generator is not actually generating a Gaussian distribution. But going to a better generator means readjusting all the hyper-parameters, so leaving it for later.
Turboderp was also inspired by QTIP when redoing quantization for their new inference engine found here. There are graphs and more details showing the performance of their quants here. I'm interested and will look into it (maybe when the inference engine matures a bit), as I haven't tested using just my 3090 for a 70B model in a long while (the few recent times I wanted to use a 70B, I used quants that are too big to fit my 3090 and thus need to be only partially offloaded).
Note that they did not quantize the embeddings with EXL3, while they might have with GGUF (not sure, still needs verification), and this might affect the perplexity graphs since they did not include the size of that tensor in the graphs. (But since they also untie tied embeddings (to quantize the output tensor), it might be hard to compare fairly depending on the model architecture.) Still looks very promising, though!
The linked doc page says "Accounting for quantization of the output layer can make a huge difference in practice, especially for smaller models. So I am including two versions of each perplexity graph, one with bitrate on the horizontal axis, and one that measures the entire VRAM footprint of the weights (not counting the embedding layer which for most inference tasks can be relegated to system RAM.)" So the bpw chart includes the embeddings layer it seems, and the VRAM one does not (both of which useful so I'm glad they offered both).
Yes.
I don't like these plots too much. The y-axis needs to be logarithmic, and it needs to be the difference to unquantized, not absolute values (else we are chasing differences between possibly different ways of computing perplexity). Also, they massively overemphasize the low bpw range. If you plot on a log scale, you get a more realistic picture. Either way, yes, trellis quantization can bring a 0.1-0.2 bpw reduction in quantized size for the same model quality. But is there any indication of performance? I could get my implementation here to be reasonably performant on CUDA, but expect the CPU implementation to be a disaster performance wise.
Yes but they are good enough for just looking at a VRAM amount and seeing the expected quality for it with the different quants.
It is more about going from exllamaV2 to V3, since EXL2 was much worse at low bpw than i-quants. (People did say it offered a better KV cache than llama.cpp, due to the Hadamard transform added here, even if the model quantization was not as good.) Even though the CUDA performance of ik_llama.cpp is lower, I still prefer it to exllamaV2 because of the iqk quants (and also the side benefit of one API implementation) when running models that fit solely on my 3090.
Exllama is designed for GPUs (and right now only CUDA with ROCm planned) and they are previewing this alongside a new version of their inference software. The Readme says, "Aside from lifting a few of the most successful features from V2 (such as the generator), ExLlamaV3 is largely rewritten from scratch to provide a cleaner, more modular framework for supporting newer architectures. It also introduces a new SOTA quantization format based on QTIP" "The framework is not yet fully optimized. Performance is lacking, especially on Ampere [...]"
That is unfortunate.
This is interesting. I have tried Hadamard transforms for model weight quantization because of the claims in the QuIP papers, but I never saw any improvement from it. I haven't tried for KV cache quantization, though.
Also, I forgot to mention it, but I did mention your PR to the QTIP authors shortly after you made this draft PR. They said "It seems like they didn't bother making the weights Gaussian first (the IP part of QTIP) before quantizing with a Gaussian codebook (3INST)." You say in the PR "This generates values that are (nearly) normally distributed." and in a commit message "I also noticed that the 3INST generator is not actually generating a Gaussian distribution." Do you think that if you followed the authors' suggestion it would result in a meaningful difference in quality, or is that something you would expect to not matter as much? (I'm not asking you to implement it if you don't know; I know this PR took a long time, and the fact that it is not CPU friendly means it has limited utility for this repo.)
It depends on what the QTIP authors mean by "they didn't bother making the weights Gaussian first". If they mean that I did not apply a Hadamard transform first: I did try that (QuIP/QuIP#/QTIP all insist on applying Hadamard transforms to model weights before quantization), but it did not improve the result in any way. The thing about Hadamard transforms and imatrix is that they do not mix well - one needs a special imatrix for that. But I have also tried this, without much success. If they mean that I have missed something in the 3INST implementation, and hence the generated sequence is not normally distributed, and it would be better otherwise, I cannot confirm that either. I did a lot of Monte Carlo stuff in the past, so I know a thing or two about random number sequences. I tried an implementation that produces a perfect Gaussian distribution (and quite a bit more efficiently than theirs), but that made results worse. I was planning to try a sequence that generates quantized values, so CPU inference will be more efficient. But then I started doing other stuff, so that never materialized. But do the QTIP authors believe theirs is much better than what I have done? My impression was that it was about the same, give or take.
That sounds interesting.
I don't know, the one line I quoted ("It seems ...") is the only thing they said to me. I was merely asking out of my own curiosity, I have no intention of testing their implementation but I may end up testing the EXL3 implementation once it has matured.
The Hadamard Bros and other people fixated on rotations aren't doing it primarily to improve LLM weight quantization. It's for eliminating downstream outliers in run-time activations and KV-cache so they can quantize those more aggressively, down to 4 bits, without scrambling model fidelity. Activations and KV-cache are only more sensitive to quantization because of the 5-10 tokens per model that act as attention sinks (like [BOS] or "\n"), which typically have activation values >100,000x larger than all the other tokens. This is why, even though 4-bit activations only cause ~0.0001% average error, it still breaks most models: the error is all concentrated in these 5-10 essential tokens. This can cause models to glitch out or loop when they're over-quantized. Activation values for attention sinks (outlier tokens) end up very finely calibrated during training, so most models immediately become flaky when they're perturbed. There's another way to resolve this besides submitting to the Hadamard cult. PrefixQuant is a fairly small patch to KV-cache and activation handling that marks the 5-10 largest outlier tokens and simply always pre-caches them into the KV-cache in full f32. Then you 4-bit quantize all the other activations and KV-cache for huge speed and memory benefits and no quality trade-off.
The author of ExllamaV3 reported that they will attempt other ideas as well and only go back to Hadamard if they don't work better. |
Finally got a chance to read the paper.
Look at "Table 5: Ablation study on quantization techniques used in PrefixQuant" and "Appendix D. More Ablation Results", the blockwise finetune that took 17 hours on Llama-3-70B with an NVIDIA-A100-80GB GPU and it having to be the correct dataset and having all the training parameters exact which contributed to their results.
This still sounds useful; they reported it took 13 minutes on Llama-3-70B with an NVIDIA-A100-80GB GPU. "Appendix H. More Visualizations" was really interesting to me. Thanks for the paper link.
Where does the equation for the optimal R(D) come from? LLaMA-3 requires about 1 bpw more to achieve the same quantization error compared to other models (see #8). Does this mean that the coding overhead there is < 0.5 bpw? Or does it rather mean that the model weights in LLaMA-3 do contain more information (which is my interpretation)?
Worst-case model weights can be approximated as maximally unpredictable Gaussian data -- essentially what LLMs might become in the limit once they're trained hard enough to reach 100% entropy (a full 8.0 bits per byte).

Shannon's rate-distortion function: this is a foundational result from information theory. I applied it in this context because you really are designing a code to preserve information as well as possible at lower bitrates. The concept is usually deployed to analyze lossy image or audio compression, or to design analog channel protocols for multiplexing light in fiber optic cables, but it applies equally in this setting, and it makes sense in retrospect that this limit would bound the maximum efficiency of any LLM quantization algorithm too. The reason this is so interesting is that we usually use information theory to discuss far more boring distortion proxies like MSE (mean squared error). For an LLM, MSE = Σ (original value − quantized value)² / (# of parameters). Have you ever investigated what this is for your quants? In any case, I just think it's beautiful seeing LLMs manifest such a clean offset from the optimal-distortion bound on such an abstract ability as faithfully reciting Wikipedia.

Also, if I rebase my prior graph to use the optimal distortion bound as the x-axis and scale the y-axis to represent bits, even other quantization methods seem to establish fairly consistent gaps off the distortion lower bound, with only minor deviations. [Note: there's likely bit-width overhead from non-weight params in EXL3 that I can't account for, so this chart may be ~5% more favorable than reality.]

And yes, your coding overhead for Llama-2 was remarkably small and very close to the limit. I grabbed Llama 3 70b and Llama 2 70b to check the entropy of the actual files: it only went up from 6.1 bits per byte to 6.27 bits per byte. Obviously L3 has more complexity packed into its weights, but in information theoretic terms there doesn't appear to be a full +1.1 bits per parameter. Maybe that accounts for 30% of the gap? The other 70% may come from Meta engineers simply using the full dynamic range of their weights better in L3 vs L2. This undoubtedly made training easier for them by making it more stable, but could have had the downstream effect of making the weights harder to quantize for your algorithm, which may have been tuned to expect numeric distributions more similar to L2 weights. Does your new Trellis quant also have a +1.1 bit gap between L2 70b and L3 70b?
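For reference, the bound being invoked above is (presumably) Shannon's rate-distortion function for a memoryless Gaussian source with variance $\sigma^2$ under mean-squared-error distortion:

$$D(R) = \sigma^2\,2^{-2R} \qquad\Longleftrightarrow\qquad R(D) = \tfrac{1}{2}\log_2\frac{\sigma^2}{D}$$

Any quantizer operating at $R$ bits per weight has an MSE at or above $D(R)$; the "coding overhead" discussed here is the gap between an actual quantizer and this curve.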
Only some recent models are trained at FP8 (such as Deepseek V3/R1); most tend to be BF16, with FP4 training currently in the research stages - see this.
Exllama-V3 added cache quantization: turboderp-org/exllamav3@cf84811. They also explain their reasoning in an issue, copied below:
I have not tried it for 70B models. It is too slow for the amount of patience I have. I know some people are OK spending 2 days quantizing a model on a GPU, but I'm not one of those.
I'm not sure I can follow. In my book, LLMs only work because there are patterns encoded in the model weights, i.e., the model weights of an LLM are pretty much the opposite of a memoryless signal as required for these equations to hold. We also know that the model weights are definitely not Gaussian, and the so called "outliers" (i.e., weights that do not fall within the expectation of a normal distribution) are more important than the others. Also, the rate distortion equation tells us something about the difference between the signal (model weights) and its approximate representation (quantized model weights), but it tells us nothing about how this will affect observations (predicted token probabilities), which are the result of a complex set of linear and non-linear operations on the signal.
The latest paper by the bitnet people is literally that: https://arxiv.org/abs/2504.18415 |
The latest quantization hype is QTIP - paper, repository. They use a Trellis approach and report impressive results, so I decided to look into this more closely.

This PR implements what they call "3INST" in their paper. Basically, if we have a seed `seed`, we generate `N` quantized values `q_i` via a short, fixed sequence of integer operations (sketched below), where `a, b, mask1` and `mask2` are suitable constants. This generates values that are (nearly) normally distributed. One uses this to describe a group of `N` quants with a single `L`-bit seed (index). Apart from borrowing the "3INST" algorithm from the QTIP paper, the implementation here has nothing else in common with QTIP - there are no Hadamard transforms, and no (tail-biting) Viterbi algorithm is utilized during quantization. Instead, in the usual i- and k-quants style, quants are organized in blocks and super-blocks with suitable block scales, and the search for the best seed during quantization is done via a clustering algorithm.
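Below is a minimal C++ sketch of what such a generator can look like. It follows the "3INST" recipe as I read it from the QTIP paper (one multiply-add, one AND, one XOR, then summing the two fp16 halves of the resulting word); the constants and names are placeholders for illustration, not the values used by this PR.

```c++
#include <cstdint>
#include <cmath>

// Minimal IEEE fp16 -> fp32 conversion, enough for this sketch.
static float fp16_to_fp32(uint16_t h) {
    int sign = (h >> 15) & 1;
    int exp  = (h >> 10) & 0x1f;
    int mant =  h        & 0x3ff;
    float v = exp == 0 ? std::ldexp((float)mant, -24)                // subnormal
                       : std::ldexp((float)(mant + 1024), exp - 25); // normal
    return sign ? -v : v;
}

// 3INST-style pseudo-random generator sketch: an LCG step, then a mask/xor
// that forces the 32-bit state into two "well-behaved" fp16 bit patterns,
// whose sum is roughly normally distributed.
// a, b, mask1, mask2 are placeholder constants, not the PR's actual values.
struct Trellis3InstSketch {
    uint32_t a     = 0x01000193u;
    uint32_t b     = 0x9E3779B9u;
    uint32_t mask1 = 0x8fff8fffu;
    uint32_t mask2 = 0x3b603b60u;

    // Expand one L-bit seed into N values.
    void generate(uint32_t seed, float* out, int N) const {
        uint32_t x = seed;
        for (int i = 0; i < N; ++i) {
            x = a * x + b;                        // multiply-add (LCG step)
            uint32_t s = (x & mask1) ^ mask2;     // and + xor
            out[i] = fp16_to_fp32((uint16_t)(s & 0xffff))
                   + fp16_to_fp32((uint16_t)(s >> 16));
        }
    }
};
```

During quantization the search runs in the opposite direction: for each group of `N` weights, find the seed whose generated values come closest to the weights.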
The PR adds 3 new quantization types:

- `IQ2_KT`: `L=16` bits for groups of `N=8` quants. Block size is 32 with a 4-bit block scale, plus a single float scale per tensor row (the 32 bits added by this scale can be safely neglected for typical tensor row sizes), so we end up using 2.125 bpw
- `IQ3_KT`: `L=12` bits for groups of `N=4` quants. Block size is also 32 with a 4-bit block scale, so 3.125 bpw
- `IQ4_KT`: `L=15` bits for groups of `N=4` quants. Blocks of 32 with 8-bit block scales, so 4.0 bpw.
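As a sanity check, the bits-per-weight figures follow directly from the seed and block-scale sizes (neglecting the per-row float scale of `IQ2_KT`, as noted above):

$$\begin{aligned}
\texttt{IQ2\_KT}:&\quad 16/8 + 4/32 = 2.125\ \text{bpw}\\
\texttt{IQ3\_KT}:&\quad 12/4 + 4/32 = 3.125\ \text{bpw}\\
\texttt{IQ4\_KT}:&\quad 15/4 + 8/32 = 4.0\ \text{bpw}
\end{aligned}$$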
Quantization accuracy

This figure shows the quantization error `PPL(Q)/PPL(bf16)-1` for LLaMA-3.1-8B-Instruct (context length of 8192 tokens). The blue symbols are k-quants, the black symbols are i-quants, cyan symbols are iqk-quants (not available in mainline `llama.cpp`), and the orange symbols are the Trellis quants added by this PR. We do see a small but noticeable improvement compared to i- and iqk-quants, with about 0.2 fewer bpw required to achieve the same quantization error.

How does this compare to the QTIP paper? Unfortunately they report results without fine tuning only for LLaMA-v2. The table shows a comparison between the 2-bit quantizations for LLaMA-v2-7B (the QTIP results are taken from Table 3 in their paper, context length is 4096 tokens).

Although there are small differences between the PPL computed by `llama.cpp` and by the tools used by the QTIP authors, the quantization error as defined above is basically independent of the specifics of the PPL calculation, so we see that the 2 bpw quantization implemented here slightly outperforms QTIP without fine tuning (at the expense of using 0.125 bpw more bits). Given this, and the above graph, my conclusion is that Trellis based quantization is a small improvement compared to i-, k-, and iqk-quants, but nowhere near the hype observed around the Internet.
Performance

The QTIP authors give TG speed for their 2 bpw variant on an RTX-6000 Ada GPU (see here) and a 7B LLaMA model. My GPU is an RTX-4080 (so same generation as theirs, but lower specs). I made a quick attempt to get QTIP going in my environment to have an apples-to-apples performance comparison, but it was not successful, so I will use the ratio of their `f16` performance on the RTX-6000 (55.9 t/s) to my `fp16` performance on the RTX-4080 (46.2 t/s) to translate QTIP performance on the RTX-6000 (188 t/s) into an estimated performance on the RTX-4080.
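Working that ratio out from the numbers just quoted gives a rough figure of

$$188 \times \frac{46.2}{55.9} \approx 155\ \text{t/s}.$$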
In comparison, I get 194 t/s for `IQ2_KT` (with flash attention enabled, which I assume they also use). These results are with the output tensor left as `f16` (which is what is done in QTIP). `IQ2_XXS` achieves 208 t/s (output as `f16`) or 216 t/s (output as `Q5_K`), so QTIP performance is far behind the performance of a model of similar size using a more efficient quantization.
Caveats

- Quantization is only implemented with `AVX2` support. The search for the optimum seed is extremely expensive (the QTIP authors say "prohibitive" for `L >= 12` without their tail-biting search space reduction), so I had to SIMDify to not have to wait forever for a quantization to finish. This PR being mostly a POC for now, I did not want to spend the time implementing for other instruction sets (or even porting to run on a GPU).
- Even with `AVX2`, quantization is slow - depending on quantization type it takes between 2.5 and 4.5 minutes to quantize LLaMA-3.1-8B on a 32-core Ryzen-5975WX CPU.
- CUDA inference is implemented via the `DMMV` mechanism in `llama.cpp`. The algorithm outputs float values, so one needs to convert to `int8_t` to use the usual quantized dot products. The cost of this conversion is likely to (more than) offset any advantage one might gain by using SIMD `int8_t` dot products.