Commit 90b99ba

docstring
1 parent c19bc88 commit 90b99ba

File tree: 1 file changed, +2 -2 lines

torchao/prototype/grouped_mm/__init__.py (+2 -2)
@@ -4,7 +4,6 @@
 
 from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
 from torchao.float8.float8_scaling_utils import (
-    get_maybe_axiswise_dim,
     hp_tensor_to_float8_dynamic,
 )
 from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig
@@ -31,7 +30,6 @@ def _grouped_scaled_mm(
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
         use_fast_accum (bool): Whether to use fast accumulation or not. Default is False.
     """
-    # perform dynamic float8 quantization using the given recipe, if specified
     return _Float8GroupedMM.apply(
         A,
         B,
@@ -43,6 +41,8 @@ def _grouped_scaled_mm(
 
 
 class _Float8GroupedMM(torch.autograd.Function):
+    """Differentiable implementation of grouped GEMM with dynamic float8 quantization."""
+
     @staticmethod
     def forward(
         ctx,
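
The new class docstring names the technique: a grouped GEMM in which each operand is dynamically quantized to float8 before the matmul. As background only, here is a minimal plain-PyTorch sketch of that idea. It is not the torchao implementation (which routes through hp_tensor_to_float8_dynamic and fused scaled-GEMM kernels); the helper name naive_grouped_mm_fp8 and its simple per-tensor scaling scheme are illustrative assumptions.

import torch

def naive_grouped_mm_fp8(A_groups, B_groups, out_dtype=torch.bfloat16):
    # Illustrative only: per-group matmul with dynamic (per-tensor) float8 scaling.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    outs = []
    for A, B in zip(A_groups, B_groups):
        A32, B32 = A.to(torch.float32), B.to(torch.float32)
        # "Dynamic" quantization: scales are derived from the current tensor values.
        a_scale = fp8_max / A32.abs().max().clamp(min=1e-12)
        b_scale = fp8_max / B32.abs().max().clamp(min=1e-12)
        A_fp8 = (A32 * a_scale).to(torch.float8_e4m3fn)
        B_fp8 = (B32 * b_scale).to(torch.float8_e4m3fn)
        # Dequantize and multiply so the sketch runs anywhere; a real kernel would
        # instead feed the float8 tensors and their scales to a fused scaled GEMM.
        out = (A_fp8.to(torch.float32) / a_scale) @ (B_fp8.to(torch.float32) / b_scale)
        outs.append(out.to(out_dtype))
    return outs

# Usage: two groups of bf16 operands, as in a grouped expert layer.
A_groups = [torch.randn(16, 32, dtype=torch.bfloat16) for _ in range(2)]
B_groups = [torch.randn(32, 64, dtype=torch.bfloat16) for _ in range(2)]
outs = naive_grouped_mm_fp8(A_groups, B_groups)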
