File tree 2 files changed +3
-4
lines changed
transformer_engine/pytorch/module
2 files changed +3
-4
lines changed Original file line number Diff line number Diff line change @@ -327,9 +327,8 @@ def forward(
327
327
ln_out .update_usage (rowwise_usage = False )
328
328
329
329
# Weight with column-wise usage is needed for dgrad GEMM.
330
- if inp .requires_grad :
331
- if isinstance (weightmat , QuantizedTensor ):
332
- weightmat .update_usage (columnwise_usage = True )
330
+ if isinstance (weightmat , QuantizedTensor ):
331
+ weightmat .update_usage (columnwise_usage = True )
333
332
334
333
if cpu_offloading :
335
334
if fp8 and weightmat is not None :
Original file line number Diff line number Diff line change @@ -415,7 +415,7 @@ def forward(
415
415
)
416
416
417
417
# Weight with column-wise usage is needed for dgrad GEMM.
418
- if is_grad_enabled and inp . requires_grad :
418
+ if is_grad_enabled :
419
419
if isinstance (fc1_weight_final , QuantizedTensor ):
420
420
fc1_weight_final .update_usage (columnwise_usage = True )
421
421
if isinstance (fc2_weight_final , QuantizedTensor ):
You can’t perform that action at this time.
0 commit comments