
Integrate memory tools into default examples #55

@ezyang

I found a diff like this useful for debugging memory problems:

```diff
diff --git a/examples/example_llama3.py b/examples/example_llama3.py
index e94f663..4c061d7 100644
--- a/examples/example_llama3.py
+++ b/examples/example_llama3.py
@@ -586,7 +586,7 @@ device = torch.device("cuda")

 def model_fn():
     model_args = TransformerModelArgs(
-        n_layers=8, vocab_size=vocab_size, max_seq_len=seqlen
+        n_layers=2, vocab_size=vocab_size, max_seq_len=seqlen
     )
     m = Transformer(model_args)
     return m
@@ -628,6 +628,8 @@ with AutoParallel(model, input_fn, mesh) as autop:
     parallel_mod = autop.apply_placement(sharding_placement)

 # run weight init on our sharded DTensor params
+torch.cuda.memory._record_memory_history(max_entries=100000)
+
 parallel_mod.to_empty(device="cuda")
 parallel_mod.init_weights()

@@ -643,3 +645,7 @@ x = (
 out = parallel_mod(*x)
 out.backward(torch.randn_like(out))
 print("All good!")
+
+torch.cuda.memory._dump_snapshot("mini_new3.pickle")
+
+print(torch.cuda.memory_summary())
```

It would be good to actually commit this to all of our examples in some form that's useful for other people.
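One way to package it, sketched below, is a small shared helper that each example can opt into. This is only a sketch under assumptions, not code from the repo: the helper name `record_memory_history` and its `snapshot_path` / `max_entries` parameters are made up here, and it relies on the private `torch.cuda.memory._record_memory_history` / `_dump_snapshot` APIs from recent PyTorch, so exact behavior may differ across versions.

```python
import contextlib

import torch


@contextlib.contextmanager
def record_memory_history(snapshot_path="memory_snapshot.pickle", max_entries=100000):
    """Record CUDA allocator history for the enclosed region, then dump a
    snapshot that can be loaded into https://pytorch.org/memory_viz."""
    torch.cuda.memory._record_memory_history(max_entries=max_entries)
    try:
        yield
    finally:
        # Write the allocator trace collected so far to disk.
        torch.cuda.memory._dump_snapshot(snapshot_path)
        # Stop recording so allocations outside the region are not traced.
        torch.cuda.memory._record_memory_history(enabled=None)
        # Also print the human-readable allocator summary.
        print(torch.cuda.memory_summary())


# Hypothetical usage inside an example: wrap just the region of interest.
# with record_memory_history("llama3_debug.pickle"):
#     parallel_mod.to_empty(device="cuda")
#     parallel_mod.init_weights()
#     out = parallel_mod(*x)
#     out.backward(torch.randn_like(out))
```

Gating this behind an optional flag (something like `--record-memory`) would keep the default runs unchanged while making the tooling discoverable.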
