
Commit 87d66b1

Merge pull request #24 from mobiusml/bfloat16
Add bfloat16 support to gemlite kernels
2 parents 43c1c2c + 2e002fa commit 87d66b1

14 files changed

Lines changed: 218 additions & 143 deletions

README.md

Lines changed: 9 additions & 11 deletions
@@ -33,6 +33,7 @@ Extensive performance results across different bitwidths, batch sizes, and devic
 - [Contributing](#contributing)
 
 # Recent Highlights
+- GemLite now supports bfloat16!
 - GemLite is now available in <a href="https://github.com/vllm-project/vllm/">vllm</a> via the <a href="https://github.com/mobiusml/hqq/">hqq</a> lib!
 - GemLite is now integrated with <a href="https://github.com/pytorch/ao">TorchAO</a>/<a href="https://github.com/sgl-project/sglang">SGLang</a> for 4-bit quantization. Check-out the <a href="https://pytorch.org/blog/accelerating-llm-inference/">blogpost</a>!
 - **Major performance improvement**: especially on the A100 and H100.
@@ -61,16 +62,9 @@ pip install git+https://github.com/mobiusml/gemlite/
 import gemlite
 from gemlite import DType, GemLiteLinear
 
-#Set accumulation dtype (only do this once)
-#gemlite.set_acc_dtype(DType.FP32) #For A100/H100 (default)
-#gemlite.set_acc_dtype(DType.FP16) #For 3090/4090 (default)
-
 #Set default packing bitwidth: use 8-bit for larger batch-sizes on A100s/H100s
 #gemlite.set_packing_bitwidth(8)
 
-#Set autotune (by default uses powers of 2 up to 1024)
-#gemlite.set_autotune_setting(lambda M: M) #max-autotune example
-
 #Main constructor
 gemlite_linear = GemLiteLinear(
     W_nbits, #weight quantization bitwidth. supported: [8, 4, 2, 1]
@@ -124,6 +118,9 @@ import gemlite
 #Ignore pre-loaded configs - if you want to start from scratch (Optional)
 #gemlite.reset_config()
 
+#Set autotune (by default uses powers of 2 up to 1024)
+#gemlite.set_autotune_setting(lambda M: M) #max-autotune example
+
 #Warm-up for A16W4 with group_size=64
 gemlite.helper.warmup(shapes=[(4096, 4096)], W_nbits=[4], group_sizes=[64], mode='static')

@@ -136,21 +133,22 @@ gemlite.cache_config('new_config.json')
 
 ## Deep Dive
 We implement various versions of the Triton kernels:
-* <b><a href="https://github.com/mobiusml/gemlite/blob/master/gemlite/triton_kernels/gemv_A16fWnO16f_int32packing.py">GEMV</a></b>: This GEMV kernel splits the activations into 1D chunks, performs the dot product using `tl.sum`, and accumulates via atomic addition. It is primarily intended for use with small batch sizes (M < 16). As `tl.atomic_add` does not support bfloat16, this kernel is limited to float16.
+* <b><a href="https://github.com/mobiusml/gemlite/blob/master/gemlite/triton_kernels/gemv_A16fWnO16f_int32packing.py">GEMV</a></b>: This GEMV kernel splits the activations into 1D chunks, performs the dot product using `tl.sum`, and accumulates via atomic addition. It is primarily intended for use with small batch sizes (M == 1).
 
 * <b><a href="https://github.com/mobiusml/gemlite/blob/master/gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py">GEMM</a></b>: This GEMM kernel is implemented similarly to <a href="https://github.com/fpgaminer/GPTQ-triton">GPTQ-triton</a>. Since it uses tensor cores, activations must be padded with zeros along the batch dimension to fit at least 16 rows. It supports both float32 and float16 accumulation for fp16 inputs, but only float32 accumulation for bfloat16.
 
-* <b><a href="https://github.com/mobiusml/gemlite/blob/master/gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py">GEMM Split-K</a></b>: This Split-K GEMM kernel is implemented similarly to <a href="https://github.com/foundation-model-stack/foundation-model-stack/blob/triton/triton/kernels/gptq/splitk_dequant_gemm.py">the gptq Split-K version</a>. We build on the gemm version above and add another dimension in the grid which splits the K dimension into multiple jobs that calculate partial sums, which are atomically added and finally stored. Split-K performs particularly well for batched LLM decoding (batch-size between 1 and 32).
+* <b><a href="https://github.com/mobiusml/gemlite/blob/master/gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py">GEMM Split-K</a></b>: This Split-K GEMM kernel is implemented similarly to <a href="https://github.com/foundation-model-stack/foundation-model-stack/blob/triton/triton/kernels/gptq/splitk_dequant_gemm.py">the gptq Split-K version</a>. We build on the gemm version above and add another dimension in the grid which splits the K dimension into multiple jobs that calculate partial sums, which are atomically added and finally stored. Split-K performs particularly well for batched LLM decoding (batch-size between 2 and 32).
 
 * <b><a href="https://github.com/mobiusml/gemlite/blob/master/gemlite/triton_kernels/gemv_revsplitK_A16fWnO16f_int32packing.py">Gemv RevSplit-K</a></b>:
 This newly proposed algorithm in GemLite operates in contrast to the GEMM Split-K approach, but within a GEMV context. By doubling the workload per Triton program launched in the GEMV kernel, it reduces the frequency of loading scales/zeros and lowers the number of threads needed. As a result, this method delivers the best performance for batch-size=1 decoding.
 
-All kernels are flexible, supporting 8, 4, 2, and 1-bit weight precisions as well as both fp16 and int8/fp8 activations.
+All kernels are flexible, supporting 8, 4, 2, and 1-bit weight precisions as well as float16, bfloat16 and int8/fp8 activations.
 
 ## Limitations
 * All kernels require a minimum group-size of 32.
-* The default accumulation DType for FP16 inputs is FP16. If you encounter precision issues, you can try <a href="https://github.com/mobiusml/gemlite/blob/master/gemlite/core.py#L28">reverting to FP32</a>.
 * <b><a href="https://github.com/mobiusml/gemlite/blob/master/gemlite/triton_kernels/gemv_revsplitK_A16fWnO16f_int32packing.py">Gemv RevSplit-K</a></b>, which is the default kernel for batch-size=1, does not work with 1-bit weights packed as 32-bit with a group-size of 32. In this case, you should use 8-bit bitpacking via `.pack(...,packing_bitwidth=8)`, or revert to using the `GEMV` kernel instead.
+* On datacenter gpus (A100, H100, H200), 8-bit packing via `gemlite.set_packing_bitwidth(8)` is faster with larger batches.
+* `bfloat16` is about 5-7% slower for `1 <= M <= 64` because of the fp32 fallback atomic addition implementation. You can set the default gemv to the Split-K kernel which could run faster for `M == 1` in some cases depending on the GPU (A100 confirmed, but slower on the H100) `gemlite.core.get_default_gemv = lambda W_nbits: 'GEMM_SPLITK' if (W_nbits < 8) else 'GEMV_SPLITK'`.
 
 ## Performance
 ### End-2-End Performance
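
The Split-K description in the Deep Dive text above can be illustrated with a small pure-Python sketch. This is not the Triton kernel from the repository: `matmul_split_k`, its `split_k` parameter, and the toy matrices are hypothetical names chosen for illustration, and the `+=` on the output stands in for the atomic addition that a real kernel would perform across parallel programs.

```python
def matmul(A, B):
    """Reference matmul: C[m][n] = sum over k of A[m][k] * B[k][n]."""
    M, K, N = len(A), len(B), len(B[0])
    return [[sum(A[m][k] * B[k][n] for k in range(K)) for n in range(N)]
            for m in range(M)]


def matmul_split_k(A, B, split_k):
    """Split the K dimension into `split_k` slices and accumulate partial sums."""
    M, K, N = len(A), len(B), len(B[0])
    assert K % split_k == 0, "this sketch assumes K divides evenly"
    chunk = K // split_k
    C = [[0] * N for _ in range(M)]
    # One iteration per K-slice; on a GPU these would be separate programs
    # running in parallel along an extra grid dimension.
    for pid_k in range(split_k):
        k0 = pid_k * chunk
        for m in range(M):
            for n in range(N):
                partial = sum(A[m][k] * B[k][n] for k in range(k0, k0 + chunk))
                C[m][n] += partial  # stands in for an atomic add
    return C


A = [[1, 2, 3, 4], [5, 6, 7, 8]]
B = [[1, 0], [0, 1], [2, 2], [3, 1]]
print(matmul_split_k(A, B, split_k=2))
```

With integer inputs the split result matches the reference exactly. With floating-point inputs the summation order changes, so small rounding differences between the two are expected.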

examples/benchmark_triton.py

Lines changed: 0 additions & 3 deletions
@@ -19,9 +19,6 @@
 from gemlite.core import GemLiteLinearTriton, DType, set_autotune, GEMLITE_ACC_DTYPE
 set_autotune({'GEMV_REVSPLITK':True, 'GEMV_SPLITK': True, 'GEMV':True, 'GEMM_SPLITK':True, 'GEMM':True}, exhaustive=True, use_cuda_graph=False)
 
-GEMLITE_ACC_DTYPE[DType.FP16] = DType.FP32 #For A100/H100
-#GEMLITE_ACC_DTYPE[DType.FP16] = DType.FP16 #For 3090/4090
-
 device = 'cuda:0'
 compute_dtype = torch.float16

examples/triton_hqq_example.py

Lines changed: 9 additions & 7 deletions
@@ -8,26 +8,28 @@ def check_valid(x, W, quant_linear, tol=1e-3):
 ############################################################################################
 from hqq.core.quantize import HQQLinear, BaseQuantizeConfig
 
-in_features, out_features = 4096*4, 4096*2
+in_features, out_features = 4096*4, 4096*4
 #W_nbits, group_size = 8, in_features
-W_nbits, group_size = 4, 128
-#W_nbits, group_size = 2, 128
+W_nbits, group_size = 4, 64
+#W_nbits, group_size = 2, 64
+compute_dtype = torch.float16 #float16 / bfloat16
 
 linear = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=False, device='cpu')
 quant_config = BaseQuantizeConfig(nbits=W_nbits, group_size=group_size, quant_zero=False, quant_scale=False, axis=1)
-hqq_layer = HQQLinear(linear, quant_config=quant_config, compute_dtype=torch.float16, device='cuda:0', del_orig=False)
+hqq_layer = HQQLinear(linear, quant_config=quant_config, compute_dtype=compute_dtype, device='cuda:0', del_orig=False)
 
 orig_shape = (out_features, in_features)
 W = hqq_layer.dequantize().reshape(orig_shape)
 ############################################################################################
 
-from gemlite.core import GemLiteLinearTriton, DType
+from gemlite.core import GemLiteLinearTriton, DType, TORCH_TO_DTYPE
+gemlite_dtype = TORCH_TO_DTYPE[compute_dtype]
 gemlite_linear = GemLiteLinearTriton(W_nbits,
     group_size=group_size,
     in_features=in_features,
     out_features=out_features,
-    input_dtype=DType.FP16,
-    output_dtype=DType.FP16)
+    input_dtype=gemlite_dtype,
+    output_dtype=gemlite_dtype)
 
 W_q = hqq_layer.unpack(dtype=torch.uint8).view(orig_shape)
 scales = hqq_layer.meta['scale']

gemlite/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-__version__ = "0.4.3"
+__version__ = "0.4.4"
 __author__ = 'Dr. Hicham Badri'
 __credits__ = 'Mobius Labs GmbH'

gemlite/core.py

Lines changed: 27 additions & 13 deletions
@@ -28,8 +28,23 @@
 ###################################################################################################################################
 # Triton backend
 ###################################################################################################################################
-GEMLITE_ACC_DTYPE = {DType.FP16: DType.FP32 if gpu_has_more_shared_memory() else DType.FP16, DType.FP8: DType.FP32, DType.FP8e5: DType.FP32, DType.INT8: DType.INT32}
-GEMLITE_TRITON_KERNELS = [gemv_A16fWnO16f, gemv_revsplitK_A16fWnO16f, gemv_splitK_A16fWnO16f, gemm_splitK_A16fWnO16f, gemm_A16fWnO16f]
+GEMLITE_ACC_DTYPE = {
+    DType.FP16: DType.FP32 if gpu_has_more_shared_memory() else DType.FP16,
+    DType.BF16: DType.FP32,
+    DType.FP32: DType.FP32,
+    DType.FP8: DType.FP32,
+    DType.FP8e5: DType.FP32,
+    DType.INT8: DType.INT32,
+}
+
+GEMLITE_TRITON_KERNELS = [
+    gemv_A16fWnO16f,
+    gemv_revsplitK_A16fWnO16f,
+    gemv_splitK_A16fWnO16f,
+    gemm_splitK_A16fWnO16f,
+    gemm_A16fWnO16f,
+]
+
 GEMLITE_TRITON_MAPPING = {kernel.matmul_type : kernel for kernel in GEMLITE_TRITON_KERNELS}
 GEMLITE_TRITON_CONFIG_CACHE = {} #Global config cache for all the kernels
 GEMLITE_TRITON_CACHE = {} #Cache used forward with warmup
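
The new `DType.BF16: DType.FP32` entry means bfloat16 inputs are accumulated in float32. The sketch below suggests why that is a reasonable default. It emulates bfloat16 in pure Python by rounding a float32 bit pattern to its upper 16 bits. The helpers `to_float32` and `to_bfloat16` are illustrative assumptions, not part of gemlite, and they ignore NaN/inf handling.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bfloat16(x: float) -> float:
    """Emulate bfloat16 by keeping the upper 16 bits of the float32 pattern,
    rounding to nearest-even. Finite inputs only (assumption of this sketch)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


term = to_bfloat16(0.001)  # a small value, already rounded to bfloat16
acc_bf16 = 1.0
acc_fp32 = 1.0
for _ in range(1000):
    acc_bf16 = to_bfloat16(acc_bf16 + term)  # accumulate in bfloat16
    acc_fp32 = to_float32(acc_fp32 + term)   # accumulate in float32

print(acc_bf16, acc_fp32)
```

In this toy case the bfloat16 accumulator never moves off 1.0, because each increment is smaller than half the spacing between adjacent bfloat16 values near 1.0, while the float32 accumulator ends close to 2.0.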
@@ -94,11 +109,14 @@ def set_acc_dtype(dtype):
     assert dtype in [DType.FP16, DType.FP32], "Invalid dtype (should be DType.FP16 or DType.FP32)."
     GEMLITE_ACC_DTYPE[DType.FP16] = dtype
 
+#Return the default gemv kernel to use for M==1
+def get_default_gemv(W_nbits: int) -> str:
+    return 'GEMV_REVSPLITK' if (W_nbits < 8) else 'GEMV_SPLITK'
 ###################################################################################################################################
 #Main class
 class GemLiteLinearTriton(torch.nn.Module):
     SUPPORTED_BITS_TRITON = [1, 2, 4, 8, 16]
-    SUPPORTED_DTYPES = [DType.FP16, DType.FP8, DType.FP8e5, DType.INT8]
+    SUPPORTED_DTYPES = [DType.FP16, DType.BF16, DType.FP32, DType.FP8, DType.FP8e5, DType.INT8]
     MIN_SIZE = 64
     PACKING_BITWIDTH = 32 #Default packing bitwidth
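
The README note about swapping the default GEMV kernel relies on `get_default_gemv` being a module-level function that `__init__` calls by name. Below is a minimal sketch of that mechanism. It uses a hypothetical `LayerSketch` class in place of `GemLiteLinearTriton` and rebinds a global in this snippet rather than `gemlite.core.get_default_gemv`, so it runs without gemlite installed.

```python
# Copy of the module-level function added to gemlite/core.py in this commit.
def get_default_gemv(W_nbits: int) -> str:
    return 'GEMV_REVSPLITK' if (W_nbits < 8) else 'GEMV_SPLITK'


class LayerSketch:
    """Hypothetical stand-in for GemLiteLinearTriton; only kernel selection is modelled."""

    def __init__(self, W_nbits: int):
        self.W_nbits = W_nbits
        # The function is looked up by name when the layer is constructed.
        self.default_gemv = get_default_gemv(self.W_nbits)


before = LayerSketch(4)

# Override taken from the README limitation note, applied to this snippet's global.
get_default_gemv = lambda W_nbits: 'GEMM_SPLITK' if (W_nbits < 8) else 'GEMV_SPLITK'

after = LayerSketch(4)
print(before.default_gemv, after.default_gemv)
```

Because the lookup happens at construction time in this sketch, layers created before the override keep their original choice. Whether the same holds inside gemlite depends on where `default_gemv` is read, which this diff only shows for `__init__`.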

@@ -144,8 +162,8 @@ def __init__(
 
         self.input_dtype = input_dtype
         self.output_dtype = output_dtype
-        self.compute_dtype = torch.float16
-        self.meta_dtype = DType.FP16
+        self.compute_dtype = DTYPE_TO_TORCH[input_dtype.value]
+        self.meta_dtype = input_dtype
         self.kernels = GEMLITE_TRITON_KERNELS
 
         #Accumulation
@@ -161,18 +179,14 @@ def __init__(
         self.forward = self.forward_auto_no_warmup
 
         #Default GEMV for packed vs. non-packed data
-        self.default_gemv = self.get_default_gemv()
+        self.default_gemv = get_default_gemv(self.W_nbits)
 
         #Set torch flags
         try:
             torch._dynamo.config.inline_inbuilt_nn_modules = False #2.5.0 fix
         except:
             pass
 
-    #Returns the default gemv choice based on the config
-    def get_default_gemv(self):
-        return 'GEMV_REVSPLITK' if (self.W_nbits < 8) else 'GEMV_SPLITK'
-
     #Override this function to perform dynamic activation quantization
     def scale_activations(self, x: Tensor) -> Tuple[Tensor, Tensor]:
         return x, self.scales_x
@@ -188,8 +202,8 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
 
         #Unpacked weights
         self.W_q = None
-        if(W_q.dtype in [torch.float16, torch.int8, torch.float8_e4m3fn, torch.float8_e5m2]):
-            if(W_q.dtype == torch.float16):
+        if(W_q.dtype in [torch.float16, torch.bfloat16, torch.int8, torch.float8_e4m3fn, torch.float8_e5m2]):
+            if(W_q.dtype in [torch.float16, torch.bfloat16]):
                 assert self.W_nbits == 16, "Invalid fp16 weights."
             else:
                 assert self.W_nbits == 8, "Invalid 8-bit weights."
@@ -281,7 +295,7 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
             self.scales = torch.tensor([[]], dtype=torch.int32, device=self.device)
 
         if(self.scales is not None):
-            self.meta_dtype = DType.FP32 if self.scales.dtype == torch.float32 else DType.FP16
+            self.meta_dtype = TORCH_TO_DTYPE[self.scales.dtype]
 
         #Force contiguous
         if(contiguous):

gemlite/dtypes.py

Lines changed: 12 additions & 0 deletions
@@ -25,6 +25,18 @@ class DType(Enum):
     8: torch.float8_e5m2,
 }
 
+TORCH_TO_DTYPE = {
+    torch.float32: DType.FP32,
+    torch.float16: DType.FP16,
+    torch.bfloat16: DType.BF16,
+    torch.float8_e4m3fn: DType.FP8,
+    torch.int8: DType.INT8,
+    torch.uint8: DType.UINT8,
+    torch.int32: DType.INT32,
+    torch.uint32: DType.UINT32,
+    torch.float8_e5m2: DType.FP8e5,
+}
+
 TORCH_DTYPE_TO_TRITON = {
     torch.float16: tl.float16,
     torch.float32: tl.float32,
