File tree (1 file changed, +2 −2 lines):

torchao/prototype/grouped_mm
@@ -4,7 +4,6 @@
 
 from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
 from torchao.float8.float8_scaling_utils import (
-    get_maybe_axiswise_dim,
     hp_tensor_to_float8_dynamic,
 )
 from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig
@@ -31,7 +30,6 @@ def _grouped_scaled_mm(
         out_dtype (Optional[torch.dtype]): The dtype of the output tensor. Currently only torch.bfloat16 is supported.
         use_fast_accum (bool): Whether to use fast accumulation or not. Default is False.
     """
-    # perform dynamic float8 quantization using the given recipe, if specified
     return _Float8GroupedMM.apply(
         A,
         B,
@@ -43,6 +41,8 @@ def _grouped_scaled_mm(
 
 
 class _Float8GroupedMM(torch.autograd.Function):
+    """Differentiable implementation of grouped GEMM with dynamic float8 quantization."""
+
     @staticmethod
     def forward(
         ctx,
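
For context, the new class docstring describes the standard torch.autograd.Function pattern: a custom forward that runs the grouped GEMM and a matching backward that produces gradients for both operands. The sketch below is illustrative only and is not the torchao implementation: it uses a plain torch.bmm over equally sized groups in place of the real dynamic-float8 grouped kernel, and the _NaiveGroupedMM name, tensor shapes, and out_dtype default are assumptions made for the example.

# Illustrative sketch of the autograd.Function pattern; not the torchao code.
import torch


class _NaiveGroupedMM(torch.autograd.Function):
    """Differentiable grouped GEMM sketch (no float8 quantization)."""

    @staticmethod
    def forward(ctx, A, B, out_dtype=torch.bfloat16):
        # A: (groups, m, k), B: (groups, k, n); save inputs for the backward pass.
        ctx.save_for_backward(A, B)
        return torch.bmm(A, B).to(out_dtype)

    @staticmethod
    def backward(ctx, grad_out):
        A, B = ctx.saved_tensors
        grad_out = grad_out.to(A.dtype)
        # Standard matmul gradients, computed per group via batched matmul.
        grad_A = torch.bmm(grad_out, B.transpose(-2, -1))
        grad_B = torch.bmm(A.transpose(-2, -1), grad_out)
        # No gradient for the out_dtype argument.
        return grad_A, grad_B, None


if __name__ == "__main__":
    A = torch.randn(4, 8, 16, requires_grad=True)
    B = torch.randn(4, 16, 32, requires_grad=True)
    out = _NaiveGroupedMM.apply(A, B)
    out.float().sum().backward()
    print(out.shape, A.grad.shape, B.grad.shape)

In the real _Float8GroupedMM, the forward presumably quantizes the inputs to float8 with dynamic scales (the module imports hp_tensor_to_float8_dynamic) before the grouped matmul; the autograd wiring follows the same apply/forward/backward structure shown above.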