diff --git a/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py index 0f718a389b..47235ea588 100644 --- a/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py +++ b/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py @@ -329,6 +329,7 @@ def input_fn(): world_mesh, mp_policy=mp_policy, compile=job_config.compile, + repeated_subgraphs=False, # switch to True to make it run faster ) as autop: autop.add_parameter_memory_constraint(low=None, high=None) diff --git a/torchtitan/experiments/autoparallel/llama3/parallelize_llama.py b/torchtitan/experiments/autoparallel/llama3/parallelize_llama.py index d7fbae2622..67838eeccc 100644 --- a/torchtitan/experiments/autoparallel/llama3/parallelize_llama.py +++ b/torchtitan/experiments/autoparallel/llama3/parallelize_llama.py @@ -91,6 +91,7 @@ def input_fn(): world_mesh, mp_policy=mp_policy, compile=job_config.compile, + repeated_subgraphs=True, # makes it run much faster ) as autop: autop.add_parameter_memory_constraint(low=None, high=None)