
Commit 02180cd

ksivaman authored and KshitijLakhani committed
Fix fp8_buf for Linear and LayerNormLinear (#1633)
* Fix fp8_buf for Linear and LayerNormLinear

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
1 parent 069dd8b commit 02180cd

File tree

3 files changed (+28, -4 lines)


transformer_engine/pytorch/module/layernorm_linear.py

+12, -2
@@ -1262,6 +1262,7 @@ def forward(
         inp: torch.Tensor,
         is_first_microbatch: Optional[bool] = None,
         fp8_output: Optional[bool] = False,
+        fp8_grad: Optional[bool] = False,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         """
         Apply layer normalization to the input followed by a linear transformation.
@@ -1292,6 +1293,13 @@ def forward(
         if skip_fp8_weight_update is not None:
             is_first_microbatch = False

+        if self.ub_overlap_rs_fprop:
+            if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
+                fp8_output = True
+        if self.ub_overlap_rs_dgrad:
+            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
+                fp8_grad = True
+
         with self.prepare_forward(
             inp, allow_non_contiguous=False  # removed .contiguous from inside the layer
         ) as inp:
@@ -1319,7 +1327,7 @@ def forward(
                 output_quantizer,
                 grad_output_quantizer,
                 grad_input_quantizer,
-            ) = self._get_quantizers(fp8_output)
+            ) = self._get_quantizers(fp8_output, fp8_grad)

             if torch.is_grad_enabled():
                 fwd_fn = _LayerNormLinear.apply
@@ -1384,7 +1392,7 @@ def forward(
             return out, ln_out
         return out

-    def _get_quantizers(self, fp8_output):
+    def _get_quantizers(self, fp8_output, fp8_grad):
         if not self.fp8:
             return [None] * 5
         grad_input_quantizer = None
@@ -1399,6 +1407,8 @@ def _get_quantizers(self, fp8_output):
         if torch.is_grad_enabled():
             grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
             grad_output_quantizer.internal = True
+            if fp8_grad:
+                grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]

         return (
             input_quantizer,
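
Note: the lines added to LayerNormLinear.forward force FP8 output and FP8 gradient quantization whenever the corresponding tensor-parallel overlap buffers are themselves FP8. Below is a minimal, self-contained sketch of that flag derivation; _FakeUbuf, _UB_REGISTRY, and resolve_fp8_flags are hypothetical stand-ins for illustration only, since the real get_ub() buffers require Transformer Engine's communication-overlap setup.

from dataclasses import dataclass


@dataclass
class _FakeUbuf:
    # Stand-in for the buffer object returned by get_ub(); the real one
    # reports whether it was allocated in FP8 via is_fp8_ubuf().
    fp8: bool

    def is_fp8_ubuf(self) -> bool:
        return self.fp8


# Hypothetical registry, keyed the same way the diff keys get_ub():
# "<ub_name>_fprop" and "<ub_name>_dgrad".
_UB_REGISTRY = {
    "proj_fprop": _FakeUbuf(fp8=True),
    "proj_dgrad": _FakeUbuf(fp8=False),
}


def get_ub(name: str) -> _FakeUbuf:
    return _UB_REGISTRY[name]


def resolve_fp8_flags(ub_name, ub_overlap_rs_fprop, ub_overlap_rs_dgrad,
                      fp8_output=False, fp8_grad=False):
    # Same control flow as the lines added to forward(): when the overlapped
    # reduce-scatter buffer is FP8, the matching flag is forced to True.
    if ub_overlap_rs_fprop and get_ub(ub_name + "_fprop").is_fp8_ubuf():
        fp8_output = True
    if ub_overlap_rs_dgrad and get_ub(ub_name + "_dgrad").is_fp8_ubuf():
        fp8_grad = True
    return fp8_output, fp8_grad


print(resolve_fp8_flags("proj", True, True))  # (True, False) with the fake registry above

With the real buffers, this means a caller no longer has to pass fp8_output or fp8_grad explicitly when an FP8 userbuffer is in use; the module infers the flags from the buffer itself.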

transformer_engine/pytorch/module/layernorm_mlp.py

+9, -2
@@ -1436,6 +1436,11 @@ def forward(
         if skip_fp8_weight_update is not None:
             is_first_microbatch = False

+        fp8_output = False
+        if self.ub_overlap_rs:
+            if get_ub("fc2_fprop").is_fp8_ubuf():
+                fp8_output = True
+
         with self.prepare_forward(inp, num_gemms=2) as inp:
             # Get quantizers
             (
@@ -1447,7 +1452,7 @@ def forward(
                 grad_fc1_output_quantizer,
                 grad_fc2_output_quantizer,
                 grad_input_quantizer,
-            ) = self._get_quantizers()
+            ) = self._get_quantizers(fp8_output)

             # Get weight tensors
             fc1_weight = self.fc1_weight
@@ -1533,7 +1538,7 @@ def forward(
             return out, ln_out
         return out

-    def _get_quantizers(self):
+    def _get_quantizers(self, fp8_output):
         (
             fc1_input_quantizer,
             fc1_weight_quantizer,
@@ -1555,6 +1560,8 @@ def _get_quantizers(self):
             )
             fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
             fc2_weight_quantizer.internal = True
+            if fp8_output:
+                output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_OUTPUT]
             if torch.is_grad_enabled():
                 grad_fc2_output_quantizer = self.quantizers["scaling_bwd"][
                     tex.FP8BwdTensors.GRAD_OUTPUT1
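
Note: in LayerNormMLP._get_quantizers, the new fp8_output argument only controls whether an FC2 output quantizer is returned. A minimal sketch of that selection pattern follows; the plain dict and string keys are hypothetical stand-ins for self.quantizers and the tex.FP8FwdTensors enum values.

# Hypothetical quantizer registry; in Transformer Engine this is
# self.quantizers with tex.FP8FwdTensors enum keys rather than strings.
quantizers = {
    "scaling_fwd": {
        "GEMM2_WEIGHT": "fc2-weight-quantizer",
        "GEMM2_OUTPUT": "fc2-output-quantizer",
    },
}


def select_output_quantizer(fp8_output: bool):
    # Mirrors the added branch: the FC2 output quantizer is only selected
    # when the output must be produced directly in FP8 (e.g. because the
    # overlapped reduce-scatter buffer is an FP8 buffer).
    output_quantizer = None
    if fp8_output:
        output_quantizer = quantizers["scaling_fwd"]["GEMM2_OUTPUT"]
    return output_quantizer


print(select_output_quantizer(False))  # None -> FC2 output stays in high precision
print(select_output_quantizer(True))   # 'fc2-output-quantizer'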

transformer_engine/pytorch/module/linear.py

+7
@@ -1104,6 +1104,13 @@ def forward(
         if skip_fp8_weight_update is not None:
             is_first_microbatch = False

+        if self.ub_overlap_rs_fprop:
+            if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
+                fp8_output = True
+        if self.ub_overlap_rs_dgrad:
+            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
+                fp8_grad = True
+
         with self.prepare_forward(
             inp,
             allow_non_contiguous=isinstance(inp, QuantizedTensor),
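
Note: Linear.forward already accepts fp8_output and fp8_grad; the added lines only force them on when the fprop/dgrad userbuffers are FP8. The short sketch below enumerates the fprop-side combinations to show when fp8_output gets forced; resolve is a hypothetical helper with the buffer queries inlined, not Transformer Engine code.

from itertools import product


def resolve(fp8_output, fp8_grad, overlap_fprop, overlap_dgrad,
            fprop_buf_is_fp8, dgrad_buf_is_fp8):
    # Faithful to the added lines: the flags are only ever forced to True,
    # never cleared, so a caller-provided True is preserved.
    if overlap_fprop and fprop_buf_is_fp8:
        fp8_output = True
    if overlap_dgrad and dgrad_buf_is_fp8:
        fp8_grad = True
    return fp8_output, fp8_grad


for overlap, buf_fp8 in product([False, True], repeat=2):
    out, _ = resolve(False, False, overlap, False, buf_fp8, False)
    print(f"ub_overlap_rs_fprop={overlap!s:<5} fp8_ubuf={buf_fp8!s:<5} -> fp8_output={out}")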
