
Commit 80b7630

rename var
1 parent dc40622 commit 80b7630

File tree

1 file changed

+9 −9 lines changed


torchao/prototype/grouped_mm/__init__.py

@@ -24,7 +24,7 @@ def _grouped_scaled_mm(
 
     Args:
         A (torch.Tensor): The first input tensor, which can be 2D or 3D.
-        B_t (torch.Tensor): The second input tensor which must be 3D.
+        B (torch.Tensor): The second input tensor which must be 3D.
         float8_recipe (Float8LinearRecipeName): The recipe to use for dynamic float8 quantization.
         offs (Optional[torch.Tensor]): The offsets to use to mark the starting index of each group. This
             is required when 2D A tensor is used, otherwise it should be None.
@@ -47,7 +47,7 @@ class _Float8GroupedMM(torch.autograd.Function):
     def forward(
         ctx,
         A: torch.Tensor,
-        B_t: torch.Tensor,
+        B: torch.Tensor,
         float8_recipe_name: Float8LinearRecipeName,
         offs: Optional[torch.Tensor] = None,
         out_dtype: Optional[torch.dtype] = None,
@@ -60,7 +60,7 @@ def forward(
 
         # perform dynamic float8 quantization using the given recipe, if specified
         assert 2 <= A.ndim <= 3, "A must be 2D or 3D"
-        assert B_t.ndim == 3, "B must be 3D"
+        assert B.ndim == 3, "B must be 3D"
 
         # offsets are required for 2D A tensor, otherwise it should be None.
         if A.ndim == 2:
@@ -87,8 +87,8 @@ def forward(
         )
 
         # Convert high precision weight tensor to float8.
-        B_t_fp8 = hp_tensor_to_float8_dynamic(
-            B_t,
+        B_fp8 = hp_tensor_to_float8_dynamic(
+            B,
             float8_config.cast_config_input.target_dtype,
             linear_mm_config=LinearMMConfig(),
             gemm_input_role=GemmInputRole.WEIGHT,
@@ -100,20 +100,20 @@ def forward(
         )
 
         # Store what we need for backward.
-        ctx.save_for_backward(A, B_t)
+        ctx.save_for_backward(A, B)
         ctx.float_config = float8_config
         ctx.offs = offs
 
         # For rowwise scaling, torch._scaled_grouped_mm requires scales without any empty dims.
         A_fp8._scale = A_fp8._scale.squeeze()
-        B_t_fp8._scale = B_t_fp8._scale.squeeze()
+        B_fp8._scale = B_fp8._scale.squeeze()
 
         # Perform scaled grouped GEMM and return result.
         return torch._scaled_grouped_mm(
             A_fp8._data,
-            B_t_fp8._data,
+            B_fp8._data,
             A_fp8._scale,
-            B_t_fp8._scale,
+            B_fp8._scale,
             offs,
             out_dtype=out_dtype,
             use_fast_accum=use_fast_accum,
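
For context, the docstring in the first hunk spells out the call contract: A may be 2D or 3D, B (formerly B_t) must be 3D, and offs is only required when A is 2D. Below is a minimal usage sketch along those lines; the import paths, the ROWWISE recipe value, the tensor shapes and dtypes, and the offs values are illustrative assumptions, not details taken from this commit.

import torch

from torchao.float8.config import Float8LinearRecipeName  # import path assumed
from torchao.prototype.grouped_mm import _grouped_scaled_mm  # prototype/private API

num_groups, total_tokens, k, n = 4, 256, 512, 1024

# A 2D A stacks the rows of all groups along dim 0, so offs must mark the
# per-group boundaries (the docstring above says offs is required when A is 2D).
A = torch.randn(total_tokens, k, device="cuda", dtype=torch.bfloat16)
B = torch.randn(num_groups, k, n, device="cuda", dtype=torch.bfloat16)  # 3D weight; layout assumed
offs = torch.tensor([64, 128, 192, 256], device="cuda", dtype=torch.int32)  # illustrative boundaries

out = _grouped_scaled_mm(
    A,
    B,
    Float8LinearRecipeName.ROWWISE,  # recipe value assumed; the code above squeezes scales for rowwise scaling
    offs=offs,
    out_dtype=torch.bfloat16,
)

Since this path goes through torch._scaled_grouped_mm with float8 inputs, the sketch also assumes a recent CUDA GPU that supports scaled grouped GEMMs.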
