README.md: 9 additions & 11 deletions
@@ -33,6 +33,7 @@ Extensive performance results across different bitwidths, batch sizes, and devic
- [Contributing](#contributing)
# Recent Highlights
+- GemLite now supports bfloat16!
- GemLite is now available in <a href="https://github.com/vllm-project/vllm/">vllm</a> via the <a href="https://github.com/mobiusml/hqq/">hqq</a> lib!
- GemLite is now integrated with <a href="https://github.com/pytorch/ao">TorchAO</a>/<a href="https://github.com/sgl-project/sglang">SGLang</a> for 4-bit quantization. Check out the <a href="https://pytorch.org/blog/accelerating-llm-inference/">blogpost</a>!
- **Major performance improvement**: especially on the A100 and H100.
We implement various versions of the Triton kernels:
-* <b><a href="https://github.com/mobiusml/gemlite/blob/master/gemlite/triton_kernels/gemv_A16fWnO16f_int32packing.py">GEMV</a></b>: This GEMV kernel splits the activations into 1D chunks, performs the dot product using `tl.sum`, and accumulates via atomic addition. It is primarily intended for use with small batch sizes (M < 16). As `tl.atomic_add` does not support bfloat16, this kernel is limited to float16.
+* <b><a href="https://github.com/mobiusml/gemlite/blob/master/gemlite/triton_kernels/gemv_A16fWnO16f_int32packing.py">GEMV</a></b>: This GEMV kernel splits the activations into 1D chunks, performs the dot product using `tl.sum`, and accumulates via atomic addition. It is primarily intended for use with small batch sizes (M == 1).
* <b><a href="https://github.com/mobiusml/gemlite/blob/master/gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py">GEMM</a></b>: This GEMM kernel is implemented similarly to <a href="https://github.com/fpgaminer/GPTQ-triton">GPTQ-triton</a>. Since it uses tensor cores, activations must be padded with zeros along the batch dimension to fit at least 16 rows. It supports both float32 and float16 accumulation for fp16 inputs, but only float32 accumulation for bfloat16.
-* <b><a href="https://github.com/mobiusml/gemlite/blob/master/gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py">GEMM Split-K</a></b>: This Split-K GEMM kernel is implemented similarly to <a href="https://github.com/foundation-model-stack/foundation-model-stack/blob/triton/triton/kernels/gptq/splitk_dequant_gemm.py">the gptq Split-K version</a>. We build on the gemm version above and add another dimension in the grid which splits the K dimension into multiple jobs that calculate partial sums, which are atomically added and finally stored. Split-K performs particularly well for batched LLM decoding (batch-size between 1 and 32).
+* <b><a href="https://github.com/mobiusml/gemlite/blob/master/gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py">GEMM Split-K</a></b>: This Split-K GEMM kernel is implemented similarly to <a href="https://github.com/foundation-model-stack/foundation-model-stack/blob/triton/triton/kernels/gptq/splitk_dequant_gemm.py">the gptq Split-K version</a>. We build on the gemm version above and add another dimension in the grid which splits the K dimension into multiple jobs that calculate partial sums, which are atomically added and finally stored. Split-K performs particularly well for batched LLM decoding (batch-size between 2 and 32).
This newly proposed algorithm in GemLite operates in contrast to the GEMM Split-K approach, but within a GEMV context. By doubling the workload per Triton program launched in the GEMV kernel, it reduces the frequency of loading scales/zeros and lowers the number of threads needed. As a result, this method delivers the best performance for batch-size=1 decoding.
-All kernels are flexible, supporting 8, 4, 2, and 1-bit weight precisions as well as both fp16 and int8/fp8 activations.
+All kernels are flexible, supporting 8, 4, 2, and 1-bit weight precisions as well as float16, bfloat16 and int8/fp8 activations.
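The Split-K idea described in the kernel list above (split the K dimension into several jobs, compute partial sums, then add them into the output) can be illustrated with a minimal pure-Python sketch. This is not GemLite's Triton kernel: the function names `matmul_reference` and `matmul_split_k` are hypothetical, the loops run sequentially, and a plain `+=` stands in for the atomic addition a GPU kernel would need.

```python
def matmul_reference(a, b):
    """Plain matrix product: a is M x K, b is K x N (lists of lists)."""
    m, k, n = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)]


def matmul_split_k(a, b, split_k):
    """Same product, but K is cut into `split_k` chunks of partial sums.

    Each value of `pid_k` plays the role of one extra grid dimension entry.
    On a GPU these jobs run concurrently, so the accumulation below would
    have to be an atomic add; here it is an ordinary sequential `+=`.
    """
    m, k, n = len(a), len(b), len(b[0])
    out = [[0] * n for _ in range(m)]
    chunk = -(-k // split_k)  # ceil(k / split_k)
    for pid_k in range(split_k):
        k_start = pid_k * chunk
        k_end = min(k_start + chunk, k)
        for i in range(m):
            for j in range(n):
                partial = sum(a[i][p] * b[p][j] for p in range(k_start, k_end))
                out[i][j] += partial  # stand-in for an atomic add
    return out


if __name__ == "__main__":
    a = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]          # M=2, K=5
    b = [[1, 0, 2], [0, 1, 3], [1, 1, 0], [2, 0, 1], [0, 2, 1]]  # K=5, N=3
    print(matmul_split_k(a, b, split_k=2) == matmul_reference(a, b))
```

With a small M (few rows), the plain product offers little parallelism over the output tile, which is why splitting along K gives the hardware more independent jobs to schedule.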
## Limitations
* All kernels require a minimum group-size of 32.
-* The default accumulation DType for FP16 inputs is FP16. If you encounter precision issues, you can try <a href="https://github.com/mobiusml/gemlite/blob/master/gemlite/core.py#L28">reverting to FP32</a>.
* <b><a href="https://github.com/mobiusml/gemlite/blob/master/gemlite/triton_kernels/gemv_revsplitK_A16fWnO16f_int32packing.py">Gemv RevSplit-K</a></b>, which is the default kernel for batch-size=1, does not work with 1-bit weights packed as 32-bit with a group-size of 32. In this case, you should use 8-bit bitpacking via `.pack(...,packing_bitwidth=8)`, or revert to using the `GEMV` kernel instead.
+* On datacenter GPUs (A100, H100, H200), 8-bit packing via `gemlite.set_packing_bitwidth(8)` is faster with larger batches.
+* `bfloat16` is about 5-7% slower for `1 <= M <= 64` because atomic addition falls back to an fp32 implementation. You can set the default GEMV to the Split-K kernel, which can run faster for `M == 1` depending on the GPU (confirmed faster on the A100, but slower on the H100): `gemlite.core.get_default_gemv = lambda W_nbits: 'GEMM_SPLITK' if (W_nbits < 8) else 'GEMV_SPLITK'`.
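For readability, the one-line lambda quoted in the last bullet can be written as a named function. The sketch below only restates the selection logic from that line; the assignment to `gemlite.core.get_default_gemv` is taken from the README text and is left commented out so the snippet runs without GemLite installed. Whether the override helps is GPU-dependent, as the bullet notes.

```python
def get_default_gemv(W_nbits):
    """Pick a kernel name from the weight bitwidth.

    Mirrors the lambda in the README line: weights below 8 bits use the
    GEMM Split-K kernel, 8-bit weights use the GEMV Split-K kernel.
    """
    return 'GEMM_SPLITK' if W_nbits < 8 else 'GEMV_SPLITK'


# Applying the override, as quoted in the README (requires gemlite):
# import gemlite.core
# gemlite.core.get_default_gemv = get_default_gemv

if __name__ == "__main__":
    for nbits in (1, 2, 4, 8):
        print(nbits, get_default_gemv(nbits))
```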