Commit cf42af1

add backward pass
1 parent c4c6c99 commit cf42af1

File tree: 1 file changed (+117, -11)

torchao/prototype/grouped_mm/__init__.py (+117, -11)

@@ -76,12 +76,12 @@ def forward(
         # Fetch float8 config from specified recipe name.
         float8_config = Float8LinearConfig.from_recipe_name(float8_recipe_name)

-        # Convert high precision input tensor to float8.
+        # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
         # A shape: (M, K) or (B, M, K)
         # A_scale shape: (M,1) or (B, M, 1)
         # torch._scaled_grouped_mm requires scales without any empty dims, so squeeze A_scale.
         # A_scale shape: (M,) or (B, M)
-        A_fp8 = hp_tensor_to_float8_dynamic(
+        A_fp8_row_major = hp_tensor_to_float8_dynamic(
             A,
             float8_config.cast_config_input.target_dtype,
             linear_mm_config=LinearMMConfig(),
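For reference, here is a minimal plain-PyTorch sketch of the rowwise dynamic scaling described in the comments above: the scale for A is computed per row along K, kept as (M, 1) (or (B, M, 1)) while quantizing, and only squeezed to (M,) when it is handed to torch._scaled_grouped_mm. The helper name, the e4m3 target dtype, and the scale convention (scale = fp8_max / amax, quantize by multiplying) are illustrative assumptions, not read from this diff.

import torch

def rowwise_cast_to_float8(A: torch.Tensor, float8_dtype=torch.float8_e4m3fn, eps: float = 1e-12):
    # Hypothetical stand-in for hp_tensor_to_float8_dynamic with rowwise (axiswise) scaling.
    # A: (M, K) or (B, M, K). One scale per row, amax reduced over the last dim (K).
    fp8_max = torch.finfo(float8_dtype).max
    amax = A.abs().amax(dim=-1, keepdim=True).clamp(min=eps)  # (M, 1) or (B, M, 1)
    scale = fp8_max / amax.to(torch.float32)                  # assumed convention: fp8_data = hp * scale
    A_fp8_data = (A * scale).clamp(-fp8_max, fp8_max).to(float8_dtype)
    return A_fp8_data, scale

A = torch.randn(16, 64, dtype=torch.bfloat16)
A_fp8_data, A_scale = rowwise_cast_to_float8(A)
A_scale = A_scale.squeeze()  # (M, 1) -> (M,), since the grouped GEMM wants scales without empty dims
print(A_fp8_data.dtype, A_scale.shape)  # torch.float8_e4m3fn torch.Size([16])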
@@ -92,15 +92,15 @@ def forward(
             ),
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
-        A_fp8_scale = A_fp8._scale.squeeze()
+        A_scale = A_fp8_row_major._scale.squeeze()

-        # Convert high precision weight tensor to float8.
+        # Convert B to float8, column-major for right operand of grouped GEMM.
         # B shape: (K,N) or (B, K, N)
         # B scales must be computed rowwise keeping the outer/final dim, so:
         # B_scale shape: (1,N) or (B, 1, N)
         # torch._scaled_grouped_mm requires scales without any empty dims, so squeeze B_scale.
         # B scale shape: (N,) or (B, N)
-        B_fp8 = hp_tensor_to_float8_dynamic(
+        B_fp8_col_major = hp_tensor_to_float8_dynamic(
             B,
             float8_config.cast_config_input.target_dtype,
             linear_mm_config=LinearMMConfig(),
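The weight operand is scaled along the other axis: for B of shape (K, N), each output column gets its own scale, so the amax reduction runs over K (dim=-2) and produces a (1, N) scale that is then squeezed to (N,). A short sketch under the same assumptions as above; the column-major memory layout required for the right operand of the grouped GEMM is a separate concern and is not shown here.

import torch

# B: (K, N). One scale per output column: reduce amax over K (dim=-2), keep the last dim.
B = torch.randn(64, 32, dtype=torch.bfloat16)
fp8_max = torch.finfo(torch.float8_e4m3fn).max                # e4m3 target dtype is an assumption
B_amax = B.abs().amax(dim=-2, keepdim=True).clamp(min=1e-12)  # (1, N)
B_scale = fp8_max / B_amax.to(torch.float32)                  # same assumed scale convention as above
B_fp8_data = (B * B_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
print(B_scale.shape, B_scale.squeeze().shape)                 # torch.Size([1, 32]) torch.Size([32])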
@@ -111,23 +111,129 @@ def forward(
             ),
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
+        B_scale = B_fp8_col_major._scale.squeeze()

         # Store what we need for backward.
         ctx.save_for_backward(A, B)
-        ctx.float_config = float8_config
+        ctx.float8_config = float8_config
         ctx.offs = offs

         # Perform scaled grouped GEMM and return result.
+        # output shape: (M, N) or (B, M, N)
         return torch._scaled_grouped_mm(
-            A_fp8._data,
-            B_fp8._data,
-            A_fp8._scale,
-            B_fp8._scale,
+            A_fp8_row_major._data,
+            B_fp8_col_major._data,
+            A_scale,
+            B_scale,
             offs,
             out_dtype=out_dtype,
             use_fast_accum=use_fast_accum,
         )

     @staticmethod
     def backward(ctx, grad_output: torch.Tensor):
-        return None, None, None, None, None, None
+        A, B = ctx.saved_tensors
+        offs = ctx.offs
+        float8_config = ctx.float8_config
+
+        # Convert grad_output to float8, row-major for left operand of grouped GEMM.
+        # grad_output shape: (M, N) or (B, M, N)
+        # grad_output_scale shape: (M, 1) or (B, M, 1)
+        # squeeze grad_output_scale to remove empty dim, as required by torch._scaled_grouped_mm.
+        # grad_output_scale shape: (M,) or (B, M)
+        grad_output_fp8_row_major = hp_tensor_to_float8_dynamic(
+            grad_output,
+            float8_config.cast_config_grad_output.target_dtype,
+            linear_mm_config=LinearMMConfig(),
+            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
+            scaling_granularity=float8_config.cast_config_grad_output.scaling_granularity,
+            axiswise_dim=get_maybe_axiswise_dim(
+                -1, float8_config.cast_config_grad_output.scaling_granularity
+            ),
+            round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
+        )
+        grad_output_scale = grad_output_fp8_row_major._scale.squeeze()
+
+        # Convert B to float8, column-major for right operand of grouped GEMM.
+        # B shape: (K,N) or (B, K, N)
+        # B scales must be computed rowwise keeping the outer/final dim, so:
+        # B_scale shape: (1,N) or (B, 1, N)
+        # torch._scaled_grouped_mm requires scales without any empty dims, so squeeze B_scale.
+        # B scale shape: (N,) or (B, N)
+        B_fp8_col_major = hp_tensor_to_float8_dynamic(
+            B,
+            float8_config.cast_config_input.target_dtype,
+            linear_mm_config=LinearMMConfig(),
+            gemm_input_role=GemmInputRole.WEIGHT,
+            scaling_granularity=float8_config.cast_config_weight.scaling_granularity,
+            axiswise_dim=get_maybe_axiswise_dim(
+                1, float8_config.cast_config_input.scaling_granularity
+            ),
+            round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
+        )
+        B_scale = B_fp8_col_major._scale.squeeze()
+
+        # grad_input = grad_output @ weight
+        # grad_A = grad_output @ B
+        grad_A = torch._scaled_grouped_mm(
+            grad_output_fp8_row_major._data,
+            B_fp8_col_major._data,
+            grad_output_scale,
+            B_scale,
+            offs,
+            out_dtype=grad_output.dtype,
+            use_fast_accum=False,
+        )
+
+        # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM.
+        grad_output_t = grad_output.transpose(-2, -1)
+
+        # grad_output_t shape: (N, M) or (B, N, M)
+        # grad_output_t_scale shape: (N, 1) or (B, N, 1)
+        # squeeze grad_output_t_scale to remove empty dim, as required by torch._scaled_grouped_mm.
+        # grad_output_t_scale shape: (N,) or (B, N)
+        grad_output_t_fp8 = hp_tensor_to_float8_dynamic(
+            grad_output_t,
+            float8_config.cast_config_grad_output.target_dtype,
+            linear_mm_config=LinearMMConfig(),
+            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
+            scaling_granularity=float8_config.cast_config_grad_output.scaling_granularity,
+            axiswise_dim=get_maybe_axiswise_dim(
+                -1, float8_config.cast_config_grad_output.scaling_granularity
+            ),
+            round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
+        )
+        grad_output_t_scale = grad_output_t_fp8._scale.squeeze()
+
+        # Convert A to float8, column-major for right operand of grouped GEMM.
+        # A shape: (M, K) or (B, M, K)
+        # A scales must be computed rowwise keeping the outer/final dim, for right operand in grouped GEMM, so:
+        # A_scale shape: (1,K) or (B, 1, K)
+        # torch._scaled_grouped_mm requires scales without any empty dims, so squeeze A_scale.
+        # A scale shape: (K,) or (B, K)
+        A_fp8 = hp_tensor_to_float8_dynamic(
+            A,
+            float8_config.cast_config_input.target_dtype,
+            linear_mm_config=LinearMMConfig(),
+            gemm_input_role=GemmInputRole.INPUT,
+            scaling_granularity=float8_config.cast_config_input.scaling_granularity,
+            axiswise_dim=get_maybe_axiswise_dim(
+                1, float8_config.cast_config_input.scaling_granularity
+            ),
+            round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
+        )
+        A_scale = A_fp8._scale.squeeze()
+
+        # grad_weight = grad_output_t @ input
+        # grad_B = grad_output_t @ A
+        grad_B = torch._scaled_grouped_mm(
+            grad_output_t_fp8._data,
+            A_fp8._data,
+            grad_output_t_scale,
+            A_scale,
+            offs,
+            out_dtype=grad_output.dtype,
+            use_fast_accum=False,
+        )
+
+        return grad_A, grad_B, None, None, None, None
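Stepping back from the float8 details, the backward added here follows the usual torch.autograd.Function contract: backward receives one grad_output for the single forward output and must return one value per forward input, with None for the non-differentiable arguments (offs, the recipe name, out_dtype, use_fast_accum). For a plain matrix multiply out = A @ B, the textbook gradients are grad_A = grad_output @ B^T and grad_B = A^T @ grad_output; the backward above follows the same two-GEMM structure, with a dynamic float8 cast in front of each operand. Below is a minimal high-precision sketch of that contract only; the forward signature is assumed from the six gradients returned, and the float8 casts and torch._scaled_grouped_mm calls are intentionally left out.

import torch

class MatmulSketch(torch.autograd.Function):
    # Simplified high-precision analogue of the grouped float8 matmul autograd function.
    # The argument list (A, B, offs, float8_recipe_name, out_dtype, use_fast_accum) is an
    # assumption inferred from the six values returned by backward in the diff above.

    @staticmethod
    def forward(ctx, A, B, offs=None, float8_recipe_name=None, out_dtype=None, use_fast_accum=False):
        ctx.save_for_backward(A, B)
        return A @ B

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        grad_A = grad_output @ B.transpose(-2, -1)   # dL/dA for out = A @ B
        grad_B = A.transpose(-2, -1) @ grad_output   # dL/dB for out = A @ B
        # One return value per forward input; None for the non-differentiable ones.
        return grad_A, grad_B, None, None, None, None

A = torch.randn(8, 16, requires_grad=True)
B = torch.randn(16, 4, requires_grad=True)
out = MatmulSketch.apply(A, B, None, "rowwise", None, False)
out.sum().backward()
print(torch.allclose(A.grad, torch.ones(8, 4) @ B.detach().t()))  # True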
