
Commit 45b2a8f

Tianyu Liang authored and facebook-github-bot committed
Replace torch quantization implementation with Triton version (#4217)
Summary:
Pull Request resolved: #4217

X-link: facebookresearch/FBGEMM#1293

Same as title

Reviewed By: jiawenliu64

Differential Revision: D75645953

fbshipit-source-id: d5e96df648a6c66d05d599218193e294b3aa3ca8
1 parent 83a537f commit 45b2a8f

File tree

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py

2 files changed: +11 -211 lines

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 11 additions & 7 deletions
@@ -13,6 +13,11 @@
 
 import torch
 import triton  # @manual=//triton:triton
+
+from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
+    triton_quantize_mx4_unpack,
+)
+
 from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
     matmul_fp8_block,
     matmul_fp8_row,
@@ -27,7 +32,6 @@
 )
 from fbgemm_gpu.experimental.gen_ai.quantize import (
     quantize_int4_preshuffle,
-    scale_mxfp4_quant,
     scale_nvfp4_quant,
 )
 
@@ -2056,8 +2060,8 @@ class MXFP4Gemm(QuantizeOpBase):
     """
 
     def quantize(self, x, w):
-        xq, x_scale = scale_mxfp4_quant(x)
-        wq, w_scale = scale_mxfp4_quant(w)
+        xq, x_scale = triton_quantize_mx4_unpack(x)
+        wq, w_scale = triton_quantize_mx4_unpack(w)
         return xq, wq, x_scale, w_scale
 
     def compute(self, xq, wq, x_scale, w_scale):
@@ -2088,11 +2092,11 @@ class MXFP4GroupedGemm(QuantizeOpBase):
     """
 
     def preprocess(self, x, w):
-        wq, w_scale = zip(*[scale_mxfp4_quant(i) for i in w])
+        wq, w_scale = zip(*[triton_quantize_mx4_unpack(i) for i in w])
         return x, wq, w_scale
 
     def quantize(self, x, wq, w_scale):
-        xq, x_scale = zip(*[scale_mxfp4_quant(i) for i in x])
+        xq, x_scale = zip(*[triton_quantize_mx4_unpack(i) for i in x])
         return xq, wq, x_scale, w_scale
 
     def compute(self, xq, wq, x_scale, w_scale):
@@ -2191,13 +2195,13 @@ class MXFP4StackedGroupedGemm(QuantizeOpBase):
     def preprocess(self, x, w):
         m_values = [i.shape[0] for i in x]
         m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
-        wq, w_scale = zip(*[scale_mxfp4_quant(i) for i in w])
+        wq, w_scale = zip(*[triton_quantize_mx4_unpack(i) for i in w])
         wq = torch.stack(wq, dim=0).contiguous()
         w_scale = torch.stack(w_scale, dim=0).contiguous()
         return x, wq, w_scale, m_sizes
 
     def quantize(self, x, wq, w_scale, m_sizes):
-        xq, x_scale = zip(*[scale_mxfp4_quant(i) for i in x])
+        xq, x_scale = zip(*[triton_quantize_mx4_unpack(i) for i in x])
         xq = torch.stack(xq, dim=0).contiguous()
         x_scale = torch.stack(x_scale, dim=0).contiguous()
         xq = xq.view(-1, xq.shape[-1])
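Note: as the hunks above show, triton_quantize_mx4_unpack is used as a drop-in replacement for the removed scale_mxfp4_quant, taking a single tensor and returning the packed FP4 data together with its block scales. A minimal sketch of the call pattern the benchmark now relies on (the shapes and bfloat16 dtype below are illustrative assumptions, not part of this commit):

    import torch

    from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
        triton_quantize_mx4_unpack,
    )

    # Hypothetical activation/weight pair on a GPU; shapes are arbitrary.
    x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)

    # Each call returns (quantized FP4 tensor, scale tensor),
    # mirroring how MXFP4Gemm.quantize consumes it above.
    xq, x_scale = triton_quantize_mx4_unpack(x)
    wq, w_scale = triton_quantize_mx4_unpack(w)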

fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py

Lines changed: 0 additions & 204 deletions
@@ -216,207 +216,3 @@ def round_up(x: int, y: int) -> int:
     torch.ops.fbgemm.scaled_fp4_quant(output, input, output_scale, input_global_scale)
     output_scale = output_scale.view(torch.float8_e4m3fn)
     return output, output_scale
-
-
-def _fp32_to_fp4_unpacked(x: torch.Tensor, ebits: int, mbits: int) -> torch.Tensor:
-    """Converts a float32 tensor to a unpacked float4 tensor.
-    Args:
-        x (torch.Tensor): The input float32 tensor.
-        ebits (int): The number of bits in the exponent.
-        mbits (int): The number of bits in the mantissa.
-    Returns:
-        torch.Tensor: The resulting unpacked float4 tensor.
-    """
-
-    def _n_ones(n: int) -> int:
-        return (1 << n) - 1
-
-    EBITS_F32, MBITS_F32 = 8, 23
-    F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
-
-    assert x.dtype == torch.float
-    assert 1 + ebits + mbits <= 8
-
-    # calculate constants
-    exp_bias = _n_ones(ebits - 1)
-    max_int = _n_ones(ebits + mbits)
-    sign_mask = 1 << (ebits + mbits)
-
-    magic_adder = _n_ones(MBITS_F32 - mbits - 1)
-
-    # all E bits and M bits are 1s
-    max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2**mbits))
-
-    # E bits = 1, M bits = 0
-    min_normal = 2 ** (1 - exp_bias)
-
-    denorm_exp = (
-        # exp bias conversion between formats
-        (F32_EXP_BIAS - exp_bias)
-        # mantissa length difference between formats
-        + (MBITS_F32 - mbits)
-        # add one to encoded exponent for denormalized numbers
-        + 1
-    )
-    denorm_mask_int = denorm_exp << MBITS_F32
-
-    # reinterpret int32 as float32
-    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(
-        torch.float32
-    )
-
-    # save the sign
-    # Note that we have torch.uint32, but some ops like cpu bit shifts
-    # do not work on it. So, we stay in int32.
-    x = x.view(torch.int32)
-    sign = x & 0x80000000
-
-    # set everything to positive, will add sign back at the end
-    x = x ^ sign
-    x = x.view(torch.float)
-
-    # rewrite saturate/denorm/norm branches without explicit data dependent
-    # control flow, to be more compiler friendly
-    saturate_mask = x >= max_normal
-    denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
-    normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))
-
-    denormal_x = x + denorm_mask_float
-    denormal_x = denormal_x.view(torch.int32)
-    denormal_x -= denorm_mask_int
-    denormal_x = denormal_x.to(torch.uint8)
-
-    normal_x = x.view(torch.int32)
-    # resulting mantissa is odd
-    mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
-    # update exponent, rounding bias part 1
-    val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
-    normal_x += val_to_add
-    # rounding bias part 2
-    normal_x += mant_odd
-    # take the bits!
-    normal_x = normal_x >> (MBITS_F32 - mbits)
-    normal_x = normal_x.to(torch.uint8)
-
-    x = torch.full_like(x, max_int, dtype=torch.uint8)
-    x = torch.where(denormal_mask, denormal_x, x)
-    x = torch.where(normal_mask, normal_x, x)
-
-    # add sign back
-    sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
-    sign_lp = sign_lp.to(torch.uint8)
-    # Right shift of a negative signed integer can fill the least significant
-    # bits with either 1s or 0s, depending on the implementation. Since PyTorch
-    # doesn't have an uint32 dtype, we mask out these bits to get just the
-    # f4 sign bit
-    sign_lp = sign_lp & sign_mask
-    x = x | sign_lp
-
-    return x.to(torch.uint8)
-
-
-def _to_blocked(x: torch.Tensor) -> torch.Tensor:
-    """Converts a tensor to the blocked layout.
-    Args:
-        x (torch.Tensor): The input tensor in non-blocked layout.
-    Returns:
-        torch.Tensor: The output tensor in the blocked layout.
-    """
-
-    def ceil_div(a: int, b: int) -> int:
-        return (a + b - 1) // b
-
-    rows, cols = x.shape
-    n_row_blocks = ceil_div(rows, 128)
-    n_col_blocks = ceil_div(cols, 4)
-
-    # Calculate the padded shape
-    padded_rows = n_row_blocks * 128
-    padded_cols = n_col_blocks * 4
-
-    padded = x
-    if (rows, cols) != (padded_rows, padded_cols):
-        padded = torch.zeros(
-            (padded_rows, padded_cols),
-            device=x.device,
-            dtype=x.dtype,
-        )
-        padded[:rows, :cols] = x
-
-    # Rearrange the blocks
-    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
-    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
-
-    return rearranged.flatten()
-
-
-# This PyTorch version refers to https://github.com/pytorch/ao/blob/v0.10.0/torchao/prototype/mx_formats/mx_tensor.py#L146
-def scale_mxfp4_quant(
-    x: torch.Tensor, block_size: int = 32
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Quantize input tensor to FP4 and return quantized tensor and scale.
-    Args:
-        x (torch.Tensor): The input tensor to be quantized to FP4
-        block_size (int): The block size to use for quantization. Default is 32.
-    Returns:
-        xq (torch.Tensor): Quantized FP4 output tensor
-        scale (torch.Tensor): Scale E8M0 tensor
-    """
-
-    F4_E2M1_MAX = 6.0
-    E8M0_EXPONENT_BIAS = 127
-    EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
-
-    # calculate the scale in e8m0 format
-    orig_shape = x.shape
-    x = x.reshape(-1, block_size)
-
-    # find max value of the data
-    # Note: this only implements the `minimally supported` version of
-    # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
-    # section 6.3.
-    max_abs = torch.amax(torch.abs(x), 1)
-    max_pos = F4_E2M1_MAX
-
-    descale = max_abs / max_pos
-    scale = torch.where(
-        torch.isnan(descale),
-        0xFF,  # Handle biased exponent for nan
-        # NOTE: descale < (torch.finfo(torch.float32).smallest_normal / 2) is handled through clamping
-        (
-            torch.clamp(
-                torch.ceil(torch.log2(descale)),
-                min=-E8M0_EXPONENT_BIAS,
-                max=E8M0_EXPONENT_BIAS,
-            )
-            + E8M0_EXPONENT_BIAS
-        ).to(torch.uint8),
-    )
-
-    descale_fp = torch.where(
-        scale == 0,
-        1.0,
-        torch.exp2(E8M0_EXPONENT_BIAS - scale.to(torch.float32)),
-    )
-
-    # scale and saturated cast the data elements to max of target dtype
-    xq = torch.clamp(x * descale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos)
-
-    xq = xq.reshape(orig_shape)
-    xq = _fp32_to_fp4_unpacked(xq, EBITS_F4_E2M1, MBITS_F4_E2M1)
-    orig_shape = [*orig_shape[:-1], orig_shape[-1] // 2]
-
-    shape = xq.shape
-    assert shape[-1] % 2 == 0
-    xq = xq.contiguous().view(-1)
-    xq = (xq[::2] << 4 | xq[1::2]).view((*shape[:-1], shape[-1] // 2))
-
-    target_numel = scale.numel() * block_size / 2
-    assert target_numel == xq.numel(), f"{target_numel} != {xq.numel()}"
-
-    scale = scale.view(torch.float8_e8m0fnu)
-    scale = scale.view(orig_shape[0], -1)
-    scale = _to_blocked(scale)
-
-    return xq, scale
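For reference, the core of the removed PyTorch path is the per-block E8M0 scale computation shown in the deleted scale_mxfp4_quant above. A rough sketch of that math on a single block, simplified from the deleted code (no NaN handling, no zero-scale special case, no FP4 bit packing, no blocked scale layout; the block values are made up, and the Triton kernel may differ in details):

    import torch

    F4_E2M1_MAX = 6.0          # largest magnitude representable in FP4 (E2M1)
    E8M0_EXPONENT_BIAS = 127   # exponent bias of the shared E8M0 scale

    block = torch.randn(32, dtype=torch.float32)  # one block_size=32 group

    # Biased exponent of ceil(log2(max|x| / 6.0)), clamped to the E8M0 range.
    descale = block.abs().max() / F4_E2M1_MAX
    scale_e8m0 = (
        torch.clamp(
            torch.ceil(torch.log2(descale)),
            min=-E8M0_EXPONENT_BIAS,
            max=E8M0_EXPONENT_BIAS,
        )
        + E8M0_EXPONENT_BIAS
    ).to(torch.uint8)

    # Rescale and saturate the block into the FP4 range before element quantization.
    descale_fp = torch.exp2(E8M0_EXPONENT_BIAS - scale_e8m0.to(torch.float32))
    scaled = torch.clamp(block * descale_fp, min=-F4_E2M1_MAX, max=F4_E2M1_MAX)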

0 commit comments
