Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support FP8 grouped GEMM with rowwise scailing #3560

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 32 additions & 64 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,22 +522,39 @@ def quantize(self, x, w):

def compute(self, xq, wq, x_scale, w_scale, m_values, kernel_name=None):
if m_values is None:
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
xq,
wq,
x_scale,
w_scale,
kernel_name=kernel_name,
)
if torch.version.cuda:
return torch.ops.fbgemm.f8f8bf16_grouped(
xq,
wq,
x_scale,
w_scale,
)
else:
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
xq,
wq,
x_scale,
w_scale,
kernel_name=kernel_name,
)
else:
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
xq,
wq,
x_scale,
w_scale,
zero_start_index_M=m_values,
kernel_name=kernel_name,
)
if torch.version.cuda:
return torch.ops.fbgemm.f8f8bf16_grouped(
xq,
wq,
x_scale,
w_scale,
zero_start_index_M=m_values,
)
else:
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
xq,
wq,
x_scale,
w_scale,
zero_start_index_M=m_values,
kernel_name=kernel_name,
)

def quantize_and_compute(self, x, w):
xq, wq, x_scale, w_scale, m_values = self.quantize(x, w)
Expand All @@ -554,55 +571,6 @@ def name(self) -> str:
def hip(self) -> bool:
return True

@property
def cuda(self) -> bool:
return False


@register_quantize_op
class FP8GroupedGemm(QuantizeOpBase):
"""
FP8 grouped matmul with tensorwise scaling.
"""

def quantize(self, x, w):
assert isinstance(
x, (list, tuple)
), "Inputs to group gemm must be a list of tensors."

# First check if N and K are fixed.
m_values = [i.shape[0] for i in x]
# Otherwise handle in eager mode.
xq, x_scale = zip(*[torch.ops.fbgemm.quantize_fp8_per_tensor(i) for i in x])
wq, w_scale = zip(*[torch.ops.fbgemm.quantize_fp8_per_tensor(i) for i in w])
joint_scales = [xs * ws for xs, ws in zip(x_scale, w_scale)]
m_values = torch.tensor(m_values).to(dtype=torch.int64, device=xq[0].device)
return xq, wq, joint_scales, m_values

def compute(self, xq, wq, scales, m_values):
return torch.ops.fbgemm.f8f8bf16_grouped(
xq,
wq,
scales,
zero_start_index_M=m_values,
)

def quantize_and_compute(self, x, w):
xq, wq, scales, m_values = self.quantize(x, w)
return self.compute(xq, wq, scales, m_values)

@property
def name(self) -> str:
if torch.version.cuda:
return "cutlass_grouped"
else:
return "ck_grouped"

@property
def hip(self) -> bool:
# Only rowwise grouped is currently supported.
return False

@property
def cuda(self) -> bool:
return True
Expand Down
Loading
Loading