P2069426742
graphsafe_run_with_rng_state is a HOP needed to replay a checkpointed op with the same RNG state it had in the forward. It should be used whenever we want to recompute RNG ops. However, as you can see in the paste, we somehow only use it in the forward, and not in the backward.
This is setting aside the fact that this op should generally be saved rather than recomputed.
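For context, this is roughly the kind of setup that produces such graphs: an SDPA call under activation checkpointing, compiled with torch.compile. The sketch below is hypothetical and simplified (it is not the torchtitan model; the shapes and the attn/fn names are illustrative), but it is the situation where the partitioner has to either save the attention outputs or replay the op in the backward with the forward's RNG state.

```python
# Hypothetical, simplified repro sketch (not the actual torchtitan model):
# a checkpointed SDPA call under torch.compile.
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def attn(q, k, v):
    # mirrors torchtitan/models/attention.py:164
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

def fn(q, k, v):
    # ask for the attention region to be recomputed in the backward
    return checkpoint(attn, q, k, v, use_reentrant=False)

compiled = torch.compile(fn)

q, k, v = (
    torch.randn(4, 16, 2048, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    for _ in range(3)
)
compiled(q, k, v).sum().backward()
```

With a setup along those lines, the forward graph in the paste wraps the attention call in the HOP as expected: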
# No stacktrace found for following nodes
graphsafe_run_with_rng_state_2 = torch.ops.higher_order.graphsafe_run_with_rng_state(torch.ops.aten._scaled_dot_product_cudnn_attention.default, permute_25, permute_26, permute_27, None, True, 0.0, True, rng_state = fwd_rng_state_2); permute_25 = permute_26 = permute_27 = fwd_rng_state_2 = None
# File: /home/xmfan/core/a/torchtitan/torchtitan/models/attention.py:164 in forward, code: return F.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=True)
getitem_20: "bf16[4, 16, 2048, 16][524288, 16, 256, 1]cuda:0" = graphsafe_run_with_rng_state_2[0]
getitem_21: "f32[4, 16, 2048, 1][32768, 2048, 1, 1]cuda:0" = graphsafe_run_with_rng_state_2[1]
getitem_26: "i64[][]cuda:0" = graphsafe_run_with_rng_state_2[6]
getitem_27: "i64[][]cuda:0" = graphsafe_run_with_rng_state_2[7]; graphsafe_run_with_rng_state_2 = None

This is the backward: instead of replaying the op with the HOP, we saved its outputs and feed them straight into the cuDNN attention backward:
getitem_29: "bf16[4, 16, 2048, 16][524288, 16, 256, 1]cuda:0",
getitem_30: "f32[4, 16, 2048, 1][32768, 2048, 1, 1]cuda:0",
getitem_35: "i64[][]cuda:0",
getitem_36: "i64[][]cuda:0",
...
# File: /home/xmfan/core/a/torchtitan/torchtitan/models/attention.py:164 in forward, code: return F.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=True)
_scaled_dot_product_cudnn_attention_backward_2 = torch.ops.aten._scaled_dot_product_cudnn_attention_backward.default(permute_151, permute_36, permute_37, permute_38, getitem_29, getitem_30, getitem_35, getitem_36, None, None, None, 2048, 2048, 0.0, True); permute_151 = permute_36 = permute_37 = permute_38 = getitem_29 = getitem_30 = getitem_35 = getitem_36 = None
getitem_62: "bf16[4, 16, 2048, 16][524288, 16, 256, 1]cuda:0" = _scaled_dot_product_cudnn_attention_backward_2[0]
getitem_63: "bf16[4, 16, 2048, 16][524288, 16, 256, 1]cuda:0" = _scaled_dot_product_cudnn_attention_backward_2[1]
getitem_64: "bf16[4, 16, 2048, 16][524288, 16, 256, 1]cuda:0" = _scaled_dot_product_cudnn_attention_backward_2[2]; _scaled_dot_product_cudnn_attention_backward_2 = None
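For reference, the property graphsafe_run_with_rng_state is supposed to provide for a recompute is bit-exact RNG replay. The eager-mode sketch below is illustrative only: it uses the plain global CUDA RNG state APIs rather than the graph-safe per-generator states the HOP actually threads through, but it shows the invariant a recomputed RNG op should satisfy.

```python
# Illustrative only: the RNG-replay invariant that graphsafe_run_with_rng_state
# is meant to guarantee when an RNG-consuming op is recomputed. This uses the
# plain global CUDA RNG state APIs, not the graph-safe per-generator states
# the HOP actually passes around.
import torch
import torch.nn.functional as F

x = torch.randn(4, 16, 2048, 16, device="cuda")

state = torch.cuda.get_rng_state()           # capture RNG state before the op
fwd = F.dropout(x, p=0.5, training=True)     # "forward" run of an RNG op

torch.cuda.set_rng_state(state)              # restore the same state...
recomputed = F.dropout(x, p=0.5, training=True)  # ...and replay the op

assert torch.equal(fwd, recomputed)          # replay must match bit-for-bit
```

In the backward graph above, though, nothing replays the attention op under fwd_rng_state_2; the saved getitems are consumed directly, so the HOP only ever shows up in the forward.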