Skip to content

Commit c55e425

Browse files
timmoon10 and KshitijLakhani
authored and committed
[PyTorch] Debug weight matrix usages for dgrad GEMM (#1637)
Make sure that weight matrix has required usages for dgrad GEMM Signed-off-by: Tim Moon <[email protected]>
1 parent 8e0853a commit c55e425

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

transformer_engine/pytorch/module/layernorm_linear.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -327,9 +327,8 @@ def forward(
327327
ln_out.update_usage(rowwise_usage=False)
328328

329329
# Weight with column-wise usage is needed for dgrad GEMM.
330-
if inp.requires_grad:
331-
if isinstance(weightmat, QuantizedTensor):
332-
weightmat.update_usage(columnwise_usage=True)
330+
if isinstance(weightmat, QuantizedTensor):
331+
weightmat.update_usage(columnwise_usage=True)
333332

334333
if cpu_offloading:
335334
if fp8 and weightmat is not None:

transformer_engine/pytorch/module/layernorm_mlp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@ def forward(
415415
)
416416

417417
# Weight with column-wise usage is needed for dgrad GEMM.
418-
if is_grad_enabled and inp.requires_grad:
418+
if is_grad_enabled:
419419
if isinstance(fc1_weight_final, QuantizedTensor):
420420
fc1_weight_final.update_usage(columnwise_usage=True)
421421
if isinstance(fc2_weight_final, QuantizedTensor):

0 commit comments

Comments (0)