96 changes: 62 additions & 34 deletions bench/gemm/gemm_ops.py
@@ -87,6 +87,20 @@
MACHETE_ENABLED = False


+try:
+    from torchao.prototype.moe_training.kernels.mxfp8 import (
+        triton_mx_block_rearrange_2d_M_groups,
+        triton_mx_block_rearrange_per_group_3d,
+    )
+    from torchao.prototype.mx_formats.kernels import (
+        triton_mx_block_rearrange,
+        triton_to_mxfp8_dim0,
+    )
+
+    TORCHAO_ENABLED = True
+except ImportError:
+    TORCHAO_ENABLED = False
+
gemm_op_registry = []


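A note on the new import guard: the torchao kernels are prototype APIs, so they are treated as an optional dependency and the benchmark falls back to the existing `to_mxfp8` / `_to_blocked` path when they are unavailable. A minimal sketch of how a caller could check which backend is active (`report_mxfp8_backend` is a hypothetical helper, not part of this PR):

```python
# Hypothetical helper, not part of this PR: reports which MXFP8 quantization
# path the benchmark will take, based on the module-level TORCHAO_ENABLED flag.
def report_mxfp8_backend() -> str:
    if TORCHAO_ENABLED:
        return "torchao fused Triton kernels (triton_to_mxfp8_dim0 et al.)"
    return "reference to_mxfp8 + _to_blocked fallback"


print(f"MXFP8 backend: {report_mxfp8_backend()}")
```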
@@ -518,10 +532,16 @@ def __init__(self):
self.torch_compile = False

def quantize(self, x, w):
-        x_scale, xq = to_mxfp8(x)
-        x_scale = _to_blocked(x_scale)
-        w_scale, wq = to_mxfp8(w)
-        w_scale = _to_blocked(w_scale)
+        if TORCHAO_ENABLED:
+            xq, x_scale = triton_to_mxfp8_dim0(x)
+            wq, w_scale = triton_to_mxfp8_dim0(w)
+            x_scale = triton_mx_block_rearrange(x_scale)
+            w_scale = triton_mx_block_rearrange(w_scale)
+        else:
+            x_scale, xq = to_mxfp8(x)
+            x_scale = _to_blocked(x_scale)
+            w_scale, wq = to_mxfp8(w)
+            w_scale = _to_blocked(w_scale)
return xq, wq.t(), x_scale, w_scale

def compute(self, xq, wq, x_scale, w_scale):
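One subtlety in the change above: the reference helper returns `(scale, data)` while torchao's `triton_to_mxfp8_dim0` returns `(data, scale)`, so the unpacking order is deliberately swapped between the two branches. A quick parity sketch for sanity-checking the two paths against each other (assumes both produce bit-identical fp8 payloads and blocked scales, which may not hold if their rounding differs):

```python
import torch


def check_mxfp8_parity(x: torch.Tensor) -> None:
    # Reference path from this file: note the (scale, data) return order.
    x_scale_ref, xq_ref = to_mxfp8(x)
    x_scale_ref = _to_blocked(x_scale_ref)
    # torchao Triton path: (data, scale) -- the opposite order.
    xq_tri, x_scale_tri = triton_to_mxfp8_dim0(x)
    x_scale_tri = triton_mx_block_rearrange(x_scale_tri)
    # Compare raw bytes; bit-exact agreement is an assumption, not a guarantee.
    assert torch.equal(xq_ref.view(torch.uint8), xq_tri.view(torch.uint8))
    assert torch.equal(
        x_scale_ref.view(torch.uint8), x_scale_tri.view(torch.uint8)
    )
```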
@@ -2971,36 +2991,44 @@ def quantize(self, x, w):
group_size, total_M + 1, group_size, dtype=torch.int32, device=x.device
)

-        # For each constituent 2d subtensor in the 3d weights, quantize and convert the scale
-        # to blocked format separately, as each is used for an independent gemm in the grouped gemm.
-        wq_list = []
-        w_scale_list = []
-        for i in range(G):
-            w_scale, wq = to_mxfp8(w[i])
-            w_scale = _to_blocked(w_scale)
-            wq_list.append(wq)
-            w_scale_list.append(w_scale)
-        wq = torch.stack(wq_list, dim=0).contiguous()
-        w_scale = torch.stack(w_scale_list, dim=0).contiguous()
-
-        # For each group along `total_M` in the 2D tensor, quantize and convert the scale
-        # to blocked format separately, as each is used for an independent gemm in the grouped gemm.
-        xq_list = []
-        x_scale_list = []
-        for i in range(G):
-            prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
-            curr_group_end = input_group_end_offsets[i]
-            group_size = curr_group_end - prev_group_end
-            if group_size > 0:
-                x_slice = x[prev_group_end:curr_group_end, :]
-                x_scale, xq = to_mxfp8(x_slice)
-                x_scale = _to_blocked(x_scale)
-                xq_list.append(xq)
-                x_scale_list.append(x_scale)
-        xq = torch.cat(xq_list, dim=0).contiguous()
-        x_scale = torch.cat(x_scale_list, dim=0).contiguous()
-        x_scale = x_scale.reshape(-1, K // block_size)
-        xq = xq.view(-1, xq.shape[-1])
+        if TORCHAO_ENABLED:
+            xq, x_scale = triton_to_mxfp8_dim0(x)
+            wq, w_scale = triton_to_mxfp8_dim0(w)
+            x_scale = triton_mx_block_rearrange_2d_M_groups(
+                x_scale, input_group_end_offsets
+            )
+            w_scale = triton_mx_block_rearrange_per_group_3d(w_scale)
+        else:
+            # For each constituent 2d subtensor in the 3d weights, quantize and convert the scale
+            # to blocked format separately, as each is used for an independent gemm in the grouped gemm.
+            wq_list = []
+            w_scale_list = []
+            for i in range(G):
+                w_scale, wq = to_mxfp8(w[i])
+                w_scale = _to_blocked(w_scale)
+                wq_list.append(wq)
+                w_scale_list.append(w_scale)
+            wq = torch.stack(wq_list, dim=0).contiguous()
+            w_scale = torch.stack(w_scale_list, dim=0).contiguous()
+
+            # For each group along `total_M` in the 2D tensor, quantize and convert the scale
+            # to blocked format separately, as each is used for an independent gemm in the grouped gemm.
+            xq_list = []
+            x_scale_list = []
+            for i in range(G):
+                prev_group_end = 0 if i == 0 else input_group_end_offsets[i - 1]
+                curr_group_end = input_group_end_offsets[i]
+                group_size = curr_group_end - prev_group_end
+                if group_size > 0:
+                    x_slice = x[prev_group_end:curr_group_end, :]
+                    x_scale, xq = to_mxfp8(x_slice)
+                    x_scale = _to_blocked(x_scale)
+                    xq_list.append(xq)
+                    x_scale_list.append(x_scale)
+            xq = torch.cat(xq_list, dim=0).contiguous()
+            x_scale = torch.cat(x_scale_list, dim=0).contiguous()
+            x_scale = x_scale.reshape(-1, K // block_size)
+            xq = xq.view(-1, xq.shape[-1])
return xq, wq, x_scale, w_scale, input_group_end_offsets

def compute(self, xq, wq, x_scale, w_scale, input_group_end_offsets):
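For context on the fused path above: `triton_mx_block_rearrange_2d_M_groups` consumes the same `input_group_end_offsets` that the fallback loop walks manually. The offsets are cumulative end rows along `total_M`, so group `i` covers rows `[offsets[i-1], offsets[i])`, with an implicit 0 before the first group. An illustrative helper mirroring the fallback loop (hypothetical, for exposition only):

```python
import torch


def iter_groups(x: torch.Tensor, input_group_end_offsets: torch.Tensor):
    # Yields the rows of `x` belonging to each group, skipping empty groups,
    # exactly as the fallback loop above slices them.
    prev_end = 0
    for end in input_group_end_offsets.tolist():
        if end > prev_end:
            yield x[prev_end:end, :]
        prev_end = end
```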