diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py index 26645330..31527452 100644 --- a/torchtitan/models/norms.py +++ b/torchtitan/models/norms.py @@ -284,7 +284,6 @@ def backward(ctx, dy): M, N = dy.shape dx = torch.empty_like(x) - dw = torch.empty_like(weight) sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)