@@ -24,7 +24,7 @@ def _grouped_scaled_mm(

    Args:
        A (torch.Tensor): The first input tensor, which can be 2D or 3D.
-       B_t (torch.Tensor): The second input tensor which must be 3D.
+       B (torch.Tensor): The second input tensor which must be 3D.
        float8_recipe (Float8LinearRecipeName): The recipe to use for dynamic float8 quantization.
        offs (Optional[torch.Tensor]): The offsets to use to mark the starting index of each group. This
            is required when 2D A tensor is used, otherwise it should be None.
@@ -47,7 +47,7 @@ class _Float8GroupedMM(torch.autograd.Function):
    def forward(
        ctx,
        A: torch.Tensor,
-       B_t: torch.Tensor,
+       B: torch.Tensor,
        float8_recipe_name: Float8LinearRecipeName,
        offs: Optional[torch.Tensor] = None,
        out_dtype: Optional[torch.dtype] = None,
@@ -60,7 +60,7 @@ def forward(

        # perform dynamic float8 quantization using the given recipe, if specified
        assert 2 <= A.ndim <= 3, "A must be 2D or 3D"
-       assert B_t.ndim == 3, "B must be 3D"
+       assert B.ndim == 3, "B must be 3D"

        # offsets are required for 2D A tensor, otherwise it should be None.
        if A.ndim == 2:
@@ -87,8 +87,8 @@ def forward(
        )

        # Convert high precision weight tensor to float8.
-       B_t_fp8 = hp_tensor_to_float8_dynamic(
-           B_t,
+       B_fp8 = hp_tensor_to_float8_dynamic(
+           B,
            float8_config.cast_config_input.target_dtype,
            linear_mm_config=LinearMMConfig(),
            gemm_input_role=GemmInputRole.WEIGHT,
@@ -100,20 +100,20 @@ def forward(
        )

        # Store what we need for backward.
-       ctx.save_for_backward(A, B_t)
+       ctx.save_for_backward(A, B)
        ctx.float_config = float8_config
        ctx.offs = offs

        # For rowwise scaling, torch._scaled_grouped_mm requires scales without any empty dims.
        A_fp8._scale = A_fp8._scale.squeeze()
-       B_t_fp8._scale = B_t_fp8._scale.squeeze()
+       B_fp8._scale = B_fp8._scale.squeeze()

        # Perform scaled grouped GEMM and return result.
        return torch._scaled_grouped_mm(
            A_fp8._data,
-           B_t_fp8._data,
+           B_fp8._data,
            A_fp8._scale,
-           B_t_fp8._scale,
+           B_fp8._scale,
            offs,
            out_dtype=out_dtype,
            use_fast_accum=use_fast_accum,
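
For context on how the renamed signature reads from the caller's side, here is a minimal usage sketch. It is not taken from the PR: the shapes, group sizes, and the Float8LinearRecipeName.ROWWISE recipe value are assumptions, and the call itself is left commented out since torch._scaled_grouped_mm requires recent CUDA hardware.

# Hypothetical usage sketch for _grouped_scaled_mm after the B_t -> B rename
# (assumed shapes and recipe; not part of the diff). A is a 2D activation
# tensor whose rows are grouped along dim 0, B is a 3D stack of per-group
# weights, and `offs` marks where each group ends along A's first dimension.
import torch

total_tokens, k, n, num_groups = 256, 512, 1024, 4
A = torch.randn(total_tokens, k, dtype=torch.bfloat16)       # 2D input, so offs is required
B = torch.randn(num_groups, k, n, dtype=torch.bfloat16)      # 3D weight stack (layout assumed)
offs = torch.tensor([64, 128, 192, 256], dtype=torch.int32)  # cumulative group end indices

# out = _grouped_scaled_mm(
#     A, B, Float8LinearRecipeName.ROWWISE, offs=offs, out_dtype=torch.bfloat16
# )
# out would then have shape (total_tokens, n), one row per input token.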