I found a diff like this useful for debugging memory problems:
diff --git a/examples/example_llama3.py b/examples/example_llama3.py
index e94f663..4c061d7 100644
--- a/examples/example_llama3.py
+++ b/examples/example_llama3.py
@@ -586,7 +586,7 @@ device = torch.device("cuda")
 def model_fn():
     model_args = TransformerModelArgs(
-        n_layers=8, vocab_size=vocab_size, max_seq_len=seqlen
+        n_layers=2, vocab_size=vocab_size, max_seq_len=seqlen
     )
     m = Transformer(model_args)
     return m
@@ -628,6 +628,8 @@ with AutoParallel(model, input_fn, mesh) as autop:
     parallel_mod = autop.apply_placement(sharding_placement)
 # run weight init on our sharded DTensor params
+torch.cuda.memory._record_memory_history(max_entries=100000)
+
 parallel_mod.to_empty(device="cuda")
 parallel_mod.init_weights()
@@ -643,3 +645,7 @@ x = (
 out = parallel_mod(*x)
 out.backward(torch.randn_like(out))
 print("All good!")
+
+torch.cuda.memory._dump_snapshot("mini_new3.pickle")
+
+print(torch.cuda.memory_summary())
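For anyone trying this: the dumped pickle can be inspected by dragging it into https://pytorch.org/memory_viz, or rendered to a standalone HTML trace plot from Python. A minimal sketch, assuming the private helper torch.cuda._memory_viz.trace_plot (its signature may vary across PyTorch versions):

import pickle

from torch.cuda._memory_viz import trace_plot

# Load the snapshot dumped by torch.cuda.memory._dump_snapshot(...)
with open("mini_new3.pickle", "rb") as f:
    snapshot = pickle.load(f)

# Write an interactive allocation trace you can open in a browser
with open("mini_new3.html", "w") as f:
    f.write(trace_plot(snapshot))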
It would be good to actually commit this to all of our examples in some form that's useful for other people.
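One option, just a sketch (the helper name and the environment variable are made up, nothing like this exists in the repo today), would be a small opt-in context manager shared by the examples, so each one can record and dump a snapshot without copy-pasting the calls:

import os
from contextlib import contextmanager

import torch


@contextmanager
def cuda_memory_snapshot(path: str, max_entries: int = 100000):
    """Record CUDA allocation history and dump a snapshot on exit.

    Only active when AUTOPARALLEL_MEMORY_SNAPSHOT is set (hypothetical
    variable), so the examples keep their current behavior by default.
    """
    enabled = bool(os.environ.get("AUTOPARALLEL_MEMORY_SNAPSHOT"))
    if enabled:
        torch.cuda.memory._record_memory_history(max_entries=max_entries)
    try:
        yield
    finally:
        if enabled:
            torch.cuda.memory._dump_snapshot(path)
            # Stop recording allocation history
            torch.cuda.memory._record_memory_history(enabled=None)
            print(torch.cuda.memory_summary())


# Usage in an example, wrapping init + forward + backward:
# with cuda_memory_snapshot("example_llama3_memory.pickle"):
#     parallel_mod.to_empty(device="cuda")
#     parallel_mod.init_weights()
#     out = parallel_mod(*x)
#     out.backward(torch.randn_like(out))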