Multinode support in torchtune #2301
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2301
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures, 2 Cancelled Jobs as of commit 63205da with merge base e6b9064; please retry the cancelled jobs.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
if device_type in dist.Backend.default_device_backend_map.keys():
    backend = dist.Backend.default_device_backend_map.get(device_type)
if enable_cpu_offload:
    backend = f"{device_type}:{backend},cpu:gloo"
I think gloo backend will also be necessary for async save
Cite your sources
I found the sources.
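For reference, a hedged sketch of how such a combined backend string can be built and handed to init_process_group. The enable_cpu_offload name comes from the snippet above; the rest is illustrative and not torchtune's exact code.

import torch.distributed as dist

def get_backend(device_type: str, enable_cpu_offload: bool = False) -> str:
    # Look up the default backend for this accelerator (e.g. "cuda" -> "nccl").
    backend = dist.Backend.default_device_backend_map.get(device_type, "gloo")
    if enable_cpu_offload:
        # CPU offload (and async checkpoint saves) run collectives on CPU tensors,
        # so a CPU-capable backend such as gloo must be registered alongside it.
        backend = f"{device_type}:{backend},cpu:gloo"
    return backend

# get_backend("cuda", enable_cpu_offload=True) -> "cuda:nccl,cpu:gloo"
# dist.init_process_group(backend=get_backend("cuda", enable_cpu_offload=True))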
docs/source/tutorials/multinode.rst
Outdated
**Low inter-node bandwidth & FSDP**
We utilize <FSDP> to distribute models over multiple devices. In order to distribute training, FSDP runs an all-gather operation for each forward pass and an all-gather plus a scatter-reduce operation for each backwards pass. These operations (usually) block training from continuing until completed and with a slow inter-node connection, training speed may be reduced.
why would we have all-gather in backward? I thought it's all gather in forward and reduce-scatter in backward?
That's a good question! The default for FSDP is to reshard after forward in order to save memory. If resharded, they need to be all-gathered before the backwards pass, too. If not, then you are correct, there's no reason to all-gather again.
gotcha, maybe explain that in the readme?
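To make the resharding trade-off concrete, here is a minimal sketch assuming torch >= 2.6 (the public fully_shard / FSDP2 API), one GPU per rank, and a torchrun launch; it is an illustration, not the recipe code in this PR.

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).cuda()
for layer in model:
    # reshard_after_forward=True frees the gathered parameters after forward, so
    # backward must all-gather them again; False keeps them resident and skips
    # that extra all-gather at the cost of more memory.
    fully_shard(layer, reshard_after_forward=True)
fully_shard(model, reshard_after_forward=True)

loss = model(torch.randn(4, 16, device="cuda")).sum()
loss.backward()  # per layer: all-gather (since we resharded) + reduce-scatter of grads
dist.destroy_process_group()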
docs/source/tutorials/multinode.rst
Outdated
What else do you want?

BLAH BLHAH BALSHD 很好
😅
# tune download meta-llama/Llama-3.3-70B-Instruct --ignore-patterns "original/consolidated*" --output-dir SHARED_CLUSTER_FS
#
# To launch on 2 nodes w/ 8 devices on a SLURM cluster, run the following command:
#   sbatch full_finetune_multinode.slurm
Is there a way to make full_finetune_multinode.slurm take in an argument to specify which config/model to run, instead of creating a new config for multinode?
Arghh, this would be a good idea. I'm leaning towards just trying to get this up there as an example of how to run since you'll really need to modify the SLURM file itself in order to set the correct number of nodes, etc.
Open to thoughts though. cc @ebsmothers @pbontrager
Yeah I agree with @acisseJZhong's suggestion. Also I think the concept of recipes + configs breaks down a bit here. I think we should either very explicitly say "this is just a demo and is not a real recipe" (i.e. we don't even list it in recipes), or we should properly integrate with tune run -- i.e. if one specifies tune run --nnodes {>1} ... we dispatch to a generic slurm script on the backend (this is just one UX.. could also require an explicit --slurm arg or something like that).
Bleh
Make them copy it?
Okay will make them copy it for now and not add to recipe registry, but I will keep the script there.
if self.fsdp_cpu_offload:
    # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
    # speed up when benchmarking fused AdamW on CPU
    training.set_torch_num_threads()
do we always want to set this?
That's a good point. Looks like this was added by Rohan, so not sure who to follow up with here. Let me dig into it.
Yeah this is a heuristic for fused Adam on CPU when CPU offload is enabled. I don't think it's optimal, but I do think that without it CPU offload training may be much slower
Should it be set for async offload too? Or pure CPU training?
Afaik it shouldn't matter for async offload and mostly has to do with fused optimizer. For pure CPU training I guess the optimizer step also happens on CPU so in that case we would potentially want it
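For illustration, a set_torch_num_threads-style helper could look like the sketch below; the core-splitting heuristic is an assumption for the example, not necessarily torchtune's exact implementation.

import os
import torch

def set_torch_num_threads() -> None:
    # LOCAL_WORLD_SIZE is set by torchrun; fall back to 1 for single-process runs.
    local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
    # Split the host's cores across local ranks so the fused AdamW step on CPU
    # gets intra-op parallelism without oversubscribing the node.
    torch.set_num_threads(max(1, (os.cpu_count() or 1) // local_world_size))

set_torch_num_threads()
print(f"Using {torch.get_num_threads()} intra-op threads for CPU ops")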
@@ -240,9 +245,16 @@ def setup(self, cfg: DictConfig) -> None:
        Setup the recipe. This includes training state (if resume_from_checkpoint is True),
        model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader.
        """
        # Set up the backend for distributed training (NCCL, GLOO, etc.)
        init_process_group(self.distributed_backend)
Curious why do we want to move this block from recipe_main to setup?
In my mind, this is doing actual setup. Therefore it should belong with the rest of the setup code, not buried at the bottom of the recipe where it's hard to find.
can you also update generate_v2_distributed recipe?
Follow-up for all distributed recipes :)
Great work and 10/10 tutorial
docs/source/tutorials/multinode.rst
Outdated
More machines means more memory! This is cool for several reasons:

1. **Bigger models**: With more memory, you can train larger models such as `Llama3.1 405B <https://ai.meta.com/blog/meta-llama-3-1/>`_, `Deepseek-V3 <https://www.deepseek.com/>`_, and more.
2. **Longer data**: More many tasks like writing code, it's helpful to have long context lengths; however longer context length means more memory needed for activations.
Separately, I would be a little bit careful about how we frame this. Like we don't actually have context parallel yet so don't wanna imply that people can continually scale context length with # of nodes.
Suggested change:
- 2. **Longer data**: More many tasks like writing code, it's helpful to have long context lengths; however longer context length means more memory needed for activations.
+ 2. **Longer data**: For many tasks like writing code, it's helpful to have long context lengths; however longer context length means more memory needed for activations.
Multi-node finetuning
=====================

Congratulations! After years of being "GPU poor", you've worked hard, saved your hard earned Bitcoin and graduated to the
sorry but discussions of crypto are banned on our docs
docs/source/tutorials/multinode.rst
Outdated
.. note::

   **Low inter-node bandwidth & FSDP** We utilize `Fully Sharded Data Parallel <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`_ to distribute models over multiple devices. In order to distribute training, FSDP runs an `all-gather <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather>`_ operation
Personally I would not point to this FSDP blog post as pretty much all the APIs given there are moot for torchtune's purposes
Fair.
docs/source/tutorials/multinode.rst
Outdated
.. note::

   **Low inter-node bandwidth & FSDP** We utilize `Fully Sharded Data Parallel <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`_ to distribute models over multiple devices. In order to distribute training, FSDP runs an `all-gather <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather>`_ operation
   for each forward pass and an all-gather plus a `scatter-reduce <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter>`_ operation for each backwards pass. These operations (usually) block training from continuing until completed and with a slow
Suggested change:
- for each forward pass and an all-gather plus a `scatter-reduce <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter>`_ operation for each backwards pass. These operations (usually) block training from continuing until completed and with a slow
+ for each forward pass and an all-gather plus a `reduce-scatter <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter>`_ operation for each backwards pass. These operations (usually) block training from continuing until completed and with a slow
I've heard it both ways.
Now that we have a downloaded model, let's check out our example SLURM bash script.

.. literalinclude:: ../../../recipes/full_finetune_multinode.slurm
Cool, didn't know you could do this. But one nit is that it includes the license, which looks a little weird in the docs imo
I can take it out of the recipes and just have people copy and paste from the tutorial? Less findable from Github tho.
Config(
    name="llama3_3/70B_full_multinode",
    file_path="llama3_3/70B_full_multinode.yaml",
),
K so are we keeping this in? I get we wanna show it in tune ls, but also it won't actually work with just tune run, right?
So the only difference with this one is that I turn off a bunch of the memory optimizations (b/c we don't need them with multi-node!). I'm happy to rename it to something like _fast, but _multinode really explains what it's for.
# You probably want to load in a virtual env w/ conda...
# module load conda
# conda activate torchtune
# ...or venv
# source torchtune/bin/activate
Why is this commented out? Is it because we don't know the user's venv/conda env? I remember wasting a bunch of time myself on this kinda stuff before, might be worth explicitly calling it out in the tutorial
Yeah, we can't make any assumptions about how they initialize their virtual env
@@ -11,6 +11,7 @@
from torchtune.training._compile import compile_loss, compile_model
from torchtune.training._distributed import (
    gather_cpu_state_dict,
    get_distributed_backend,
Are we fully deprecating get_world_size_and_rank in this PR? Seems like the API still exists and is imported here too.
The API just moved to utils from training.
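In other words, recipes would now do something like the following (import path assumed from the comment above):

# Before: from torchtune.training import get_world_size_and_rank
from torchtune.utils import get_world_size_and_rank

world_size, rank = get_world_size_and_rank()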
self.distributed_backend = training.get_distributed_backend(
    device_type,
    offload_ops_to_cpu=self.fsdp_cpu_offload
    or self._enable_async_checkpointing,
Did you run with async checkpointing? Pretty interested to know how much time it saves on multiple nodes
Nope.
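For context, a hedged sketch of what an async save looks like with torch.distributed.checkpoint (assumes torch >= 2.4, a torchrun launch, and a hypothetical /shared/ckpt path; this is not torchtune's checkpointer code):

import torch.nn as nn
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

# gloo keeps the staging collectives CPU-friendly, which is why the backend
# string above needs a cpu:gloo component when async checkpointing is enabled.
dist.init_process_group(backend="gloo")

model = nn.Linear(8, 8)
state_dict = {"model": model.state_dict()}

future = dcp.async_save(state_dict, checkpoint_id="/shared/ckpt/step_100")  # hypothetical path
# ... training continues while the checkpoint is written in the background ...
future.result()  # block before the next save (or at shutdown)
dist.destroy_process_group()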
"cpu": "gloo", | ||
"xpu": "xccl", | ||
} | ||
# TODO: Uncomment the following line once PyTorch 2.6 is released |
it's released
Okay, then big question: Do we force people to upgrade immediately to PyTorch 2.6? Otherwise, this function will not work.
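One hedged alternative to a hard pin (an illustrative sketch, not the code in this PR) is to gate the new mapping entry on the installed torch version:

import torch
import torch.distributed as dist

backend_map = dict(dist.Backend.default_device_backend_map)
backend_map.setdefault("cpu", "gloo")
# torch.__version__ is a TorchVersion, which supports PEP 440-style comparisons.
if torch.__version__ >= "2.6.0":
    backend_map.setdefault("xpu", "xccl")  # xccl only ships with torch >= 2.6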
Officially declaring multi-node open for business in torchtune!
Context
This has been an explicit ask by several users (#2161, #2142) and although things should work fairly easily OOTB, we hadn't been able to test it and therefore didn't want to make any claims that we couldn't back up. Therefore, I sent myself on a quest to set up my own SLURM cluster, waste a lot of money, and test our multi-node scripts!
Changes
- get_world_size_and_rank is now imported from utils instead of training in recipes
- New get_distributed_backend method that mirrors torchtitan's
- full_finetune_distributed updated to utilize get_distributed_backend instead of setting "cuda:nccl,cpu:gloo"
- New full_finetune_multinode.slurm script

Testing
Experiments were run with a SLURM cluster w/ 2 worker nodes set up on Nebius AI, following this tutorial: https://docs.nebius.com/compute/clusters/slurm.
Weights & Biases snapshot:

Follow ups
- Add get_distributed_backend to ALL distributed recipes