Multinode support in torchtune #2301

Open

wants to merge 26 commits into base: main

Conversation

Contributor

@joecummings joecummings commented Jan 27, 2025

Officially declaring multi-node open for business in torchtune!

Context

Multi-node support has been an explicit ask from several users (#2161, #2142), and although things should work fairly easily OOTB, we hadn't been able to test it and therefore didn't want to make any claims we couldn't back up. So I sent myself on a quest to set up my own SLURM cluster, waste a lot of money, and test our multi-node scripts!

Changes

  • Fully deprecate training.get_world_size_and_rank in recipes (the API now lives in utils)
  • Add a get_distributed_backend method that mirrors torchtitan's (see the sketch after this list)
  • Update full_finetune_distributed to use get_distributed_backend instead of hard-coding "cuda:nccl,cpu:gloo"
  • Add a full_finetune_multinode.slurm script
  • Add a tutorial on multi-node training in torchtune
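
As a rough sketch of what the new helper does (hedged; not the exact code in this PR, and argument names may differ):

import torch.distributed as dist

def get_distributed_backend(device_type: str, offload_ops_to_cpu: bool = False) -> str:
    """Pick a c10d backend for the given accelerator, optionally pairing it with gloo for CPU ops."""
    backend = "nccl"  # reasonable default for CUDA clusters
    if device_type in dist.Backend.default_device_backend_map:
        backend = dist.Backend.default_device_backend_map[device_type]
    if offload_ops_to_cpu:
        # e.g. "cuda:nccl,cpu:gloo" so that ops offloaded to CPU also have a backend
        backend = f"{device_type}:{backend},cpu:gloo"
    return backend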

Testing

Experiments were run with a SLURM cluster w/ 2 worker nodes set up on Nebius AI, following this tutorial: https://docs.nebius.com/compute/clusters/slurm.

Node IP: slurm-worker-1
I0129 19:29:59.590000 8654 torch/distributed/run.py:675] Using nproc_per_node=8.
W0129 19:29:59.591000 8654 torch/distributed/run.py:792]
W0129 19:29:59.591000 8654 torch/distributed/run.py:792] *****************************************
W0129 19:29:59.591000 8654 torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0129 19:29:59.591000 8654 torch/distributed/run.py:792] *****************************************
I0129 19:29:59.591000 8654 torch/distributed/launcher/api.py:195] Starting elastic_operator with launch configs:
I0129 19:29:59.591000 8654 torch/distributed/launcher/api.py:195]   entrypoint       : /mnt/slurm/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py
I0129 19:29:59.591000 8654 torch/distributed/launcher/api.py:195]   min_nodes        : 2
I0129 19:29:59.591000 8654 torch/distributed/launcher/api.py:195]   max_nodes        : 2
I0129 19:29:59.591000 8654 torch/distributed/launcher/api.py:195]   nproc_per_node   : 8
I0129 19:29:59.591000 8654 torch/distributed/launcher/api.py:195]   run_id           : 101
I0129 19:29:59.591000 8654 torch/distributed/launcher/api.py:195]   rdzv_backend     : c10d
I0129 19:29:59.591000 8654 torch/distributed/launcher/api.py:195]   rdzv_endpoint    : slurm-worker-1:29500
I0129 19:29:59.591000 8654 torch/distributed/launcher/api.py:195]   rdzv_configs     : {'timeout': 900}
I0129 19:29:59.591000 8654 torch/distributed/launcher/api.py:195]   max_restarts     : 0
I0129 19:29:59.591000 8654 torch/distributed/launcher/api.py:195]   monitor_interval : 0.1
I0129 19:29:59.591000 8654 torch/distributed/launcher/api.py:195]   log_dir          : /tmp/torchelastic_xcn6sasp
I0129 19:29:59.591000 8654 torch/distributed/launcher/api.py:195]   metrics_cfg      : {}
I0129 19:29:59.591000 8654 torch/distributed/launcher/api.py:195]
I0129 19:29:59.626000 8689 torch/distributed/run.py:675] Using nproc_per_node=8.
W0129 19:29:59.627000 8689 torch/distributed/run.py:792]
W0129 19:29:59.627000 8689 torch/distributed/run.py:792] *****************************************
W0129 19:29:59.627000 8689 torch/distributed/run.py:792] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0129 19:29:59.627000 8689 torch/distributed/run.py:792] *****************************************
I0129 19:29:59.627000 8689 torch/distributed/launcher/api.py:195] Starting elastic_operator with launch configs:
I0129 19:29:59.627000 8689 torch/distributed/launcher/api.py:195]   entrypoint       : /mnt/slurm/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py
I0129 19:29:59.627000 8689 torch/distributed/launcher/api.py:195]   min_nodes        : 2
I0129 19:29:59.627000 8689 torch/distributed/launcher/api.py:195]   max_nodes        : 2
I0129 19:29:59.627000 8689 torch/distributed/launcher/api.py:195]   nproc_per_node   : 8
I0129 19:29:59.627000 8689 torch/distributed/launcher/api.py:195]   run_id           : 101
I0129 19:29:59.627000 8689 torch/distributed/launcher/api.py:195]   rdzv_backend     : c10d
I0129 19:29:59.627000 8689 torch/distributed/launcher/api.py:195]   rdzv_endpoint    : slurm-worker-1:29500
I0129 19:29:59.627000 8689 torch/distributed/launcher/api.py:195]   rdzv_configs     : {'timeout': 900}
I0129 19:29:59.627000 8689 torch/distributed/launcher/api.py:195]   max_restarts     : 0
I0129 19:29:59.627000 8689 torch/distributed/launcher/api.py:195]   monitor_interval : 0.1
I0129 19:29:59.627000 8689 torch/distributed/launcher/api.py:195]   log_dir          : /tmp/torchelastic_pbqauwt1
I0129 19:29:59.627000 8689 torch/distributed/launcher/api.py:195]   metrics_cfg      : {}
I0129 19:29:59.627000 8689 torch/distributed/launcher/api.py:195]
I0129 19:29:59.632000 8689 torch/distributed/elastic/agent/server/api.py:860] [default] starting workers for entrypoint: python3
I0129 19:29:59.632000 8689 torch/distributed/elastic/agent/server/api.py:677] [default] Rendezvous'ing worker group
I0129 19:29:59.943000 8654 torch/distributed/elastic/agent/server/api.py:860] [default] starting workers for entrypoint: python3
I0129 19:29:59.944000 8654 torch/distributed/elastic/agent/server/api.py:677] [default] Rendezvous'ing worker group
I0129 19:30:00.934000 8689 torch/distributed/elastic/agent/server/api.py:525] [default] Rendezvous complete for workers. Result:
I0129 19:30:00.934000 8689 torch/distributed/elastic/agent/server/api.py:525]   restart_count=0
I0129 19:30:00.934000 8689 torch/distributed/elastic/agent/server/api.py:525]   master_addr=computeinstance-e00ffcpw4hzd6ws0gp
I0129 19:30:00.934000 8689 torch/distributed/elastic/agent/server/api.py:525]   master_port=37069
I0129 19:30:00.934000 8689 torch/distributed/elastic/agent/server/api.py:525]   group_rank=1
I0129 19:30:00.934000 8689 torch/distributed/elastic/agent/server/api.py:525]   group_world_size=2
I0129 19:30:00.934000 8689 torch/distributed/elastic/agent/server/api.py:525]   local_ranks=[0, 1, 2, 3, 4, 5, 6, 7]
I0129 19:30:00.934000 8689 torch/distributed/elastic/agent/server/api.py:525]   role_ranks=[8, 9, 10, 11, 12, 13, 14, 15]
I0129 19:30:00.934000 8689 torch/distributed/elastic/agent/server/api.py:525]   global_ranks=[8, 9, 10, 11, 12, 13, 14, 15]
I0129 19:30:00.934000 8689 torch/distributed/elastic/agent/server/api.py:525]   role_world_sizes=[16, 16, 16, 16, 16, 16, 16, 16]
I0129 19:30:00.934000 8689 torch/distributed/elastic/agent/server/api.py:525]   global_world_sizes=[16, 16, 16, 16, 16, 16, 16, 16]
I0129 19:30:00.934000 8689 torch/distributed/elastic/agent/server/api.py:525]
I0129 19:30:00.934000 8689 torch/distributed/elastic/agent/server/api.py:685] [default] Starting worker group
I0129 19:30:00.934000 8689 torch/distributed/elastic/agent/server/local_elastic_agent.py:298] use_agent_store: True
I0129 19:30:00.934000 8654 torch/distributed/elastic/agent/server/api.py:525] [default] Rendezvous complete for workers. Result:
I0129 19:30:00.934000 8654 torch/distributed/elastic/agent/server/api.py:525]   restart_count=0
I0129 19:30:00.934000 8654 torch/distributed/elastic/agent/server/api.py:525]   master_addr=computeinstance-e00ffcpw4hzd6ws0gp
I0129 19:30:00.934000 8654 torch/distributed/elastic/agent/server/api.py:525]   master_port=37069
I0129 19:30:00.934000 8654 torch/distributed/elastic/agent/server/api.py:525]   group_rank=0
I0129 19:30:00.934000 8654 torch/distributed/elastic/agent/server/api.py:525]   group_world_size=2
I0129 19:30:00.934000 8654 torch/distributed/elastic/agent/server/api.py:525]   local_ranks=[0, 1, 2, 3, 4, 5, 6, 7]
I0129 19:30:00.934000 8654 torch/distributed/elastic/agent/server/api.py:525]   role_ranks=[0, 1, 2, 3, 4, 5, 6, 7]
I0129 19:30:00.934000 8654 torch/distributed/elastic/agent/server/api.py:525]   global_ranks=[0, 1, 2, 3, 4, 5, 6, 7]
I0129 19:30:00.935000 8689 torch/distributed/elastic/agent/server/local_elastic_agent.py:192] Environment variable 'TORCHELASTIC_ENABLE_FILE_TIMER' not found. Do not start FileTimerServer.
I0129 19:30:00.934000 8654 torch/distributed/elastic/agent/server/api.py:525]   role_world_sizes=[16, 16, 16, 16, 16, 16, 16, 16]
I0129 19:30:00.934000 8654 torch/distributed/elastic/agent/server/api.py:525]   global_world_sizes=[16, 16, 16, 16, 16, 16, 16, 16]
I0129 19:30:00.934000 8654 torch/distributed/elastic/agent/server/api.py:525]
I0129 19:30:00.935000 8689 torch/distributed/elastic/agent/server/local_elastic_agent.py:236] Environment variable 'TORCHELASTIC_HEALTH_CHECK_PORT' not found. Do not start health check.
I0129 19:30:00.934000 8654 torch/distributed/elastic/agent/server/api.py:685] [default] Starting worker group
I0129 19:30:00.935000 8654 torch/distributed/elastic/agent/server/local_elastic_agent.py:298] use_agent_store: True
I0129 19:30:00.935000 8654 torch/distributed/elastic/agent/server/local_elastic_agent.py:192] Environment variable 'TORCHELASTIC_ENABLE_FILE_TIMER' not found. Do not start FileTimerServer.
I0129 19:30:00.935000 8654 torch/distributed/elastic/agent/server/local_elastic_agent.py:236] Environment variable 'TORCHELASTIC_HEALTH_CHECK_PORT' not found. Do not start health check.
Running FullFinetuneRecipeDistributed with resolved config:

batch_size: 8
checkpoint_dir: /mnt/slurm/Llama-3.3-70B-Instruct
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /mnt/slurm/Llama3.3-70B/
  checkpoint_files:
  - model-00001-of-00030.safetensors
  - model-00002-of-00030.safetensors
  - model-00003-of-00030.safetensors
  - model-00004-of-00030.safetensors
  - model-00005-of-00030.safetensors
  - model-00006-of-00030.safetensors
  - model-00007-of-00030.safetensors
  - model-00008-of-00030.safetensors
  - model-00009-of-00030.safetensors
  - model-00010-of-00030.safetensors
  - model-00011-of-00030.safetensors
  - model-00012-of-00030.safetensors
  - model-00013-of-00030.safetensors
  - model-00014-of-00030.safetensors
  - model-00015-of-00030.safetensors
  - model-00016-of-00030.safetensors
  - model-00017-of-00030.safetensors
  - model-00018-of-00030.safetensors
  - model-00019-of-00030.safetensors
  - model-00020-of-00030.safetensors
  - model-00021-of-00030.safetensors
  - model-00022-of-00030.safetensors
  - model-00023-of-00030.safetensors
  - model-00024-of-00030.safetensors
  - model-00025-of-00030.safetensors
  - model-00026-of-00030.safetensors
  - model-00027-of-00030.safetensors
  - model-00028-of-00030.safetensors
  - model-00029-of-00030.safetensors
  - model-00030-of-00030.safetensors
  model_type: LLAMA3
  output_dir: /mnt/slurm/Llama3.3-70B-fft-output
  recipe_checkpoint: null
compile: true
custom_sharded_layers:
- tok_embeddings
- output
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: true
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
epochs: 1
fsdp_cpu_offload: false
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  dir: /mnt/slurm/Llama3.3-70B-fft-output
  name: llama3.3-70B-v0-compile-1024
  project: torchtune-multinode
model:
  _component_: torchtune.models.llama3_3.llama3_3_70b
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  lr: 2.0e-05
optimizer_in_bwd: false
output_dir: /mnt/slurm/Llama3.3-70B-fft-output
profiler:
  _component_: torchtune.training.setup_torch_profiler
  active_steps: 2
  cpu: true
  cuda: true
  enabled: false
  num_cycles: 1
  output_dir: /mnt/slurm/Llama3.3-70B-fft-output/profiling_outputs
  profile_memory: false
  record_shapes: true
  wait_steps: 5
  warmup_steps: 3
  with_flops: false
  with_stack: false
resume_from_checkpoint: false
seed: null
shuffle: true
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  max_seq_len: 1024
  path: /mnt/slurm/Llama3.3-70B/original/tokenizer.model

Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. Enabling activation offloading should reduce memory further.
Setting manual seed to local seed 983303568. Local seed is seed + rank = 983303568 + 0
wandb: Currently logged in as: jcummings. Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Tracking run with wandb version 0.19.4
wandb: Run data is saved locally in /mnt/slurm/Llama3.3-70B-fft-output/wandb/run-20250129_193034-m3m6k5gf
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run llama3.3-70B-v0-compile-1024
wandb: ⭐️ View project at https://wandb.ai/jcummings/torchtune-multinode
wandb: 🚀 View run at https://wandb.ai/jcummings/torchtune-multinode/runs/m3m6k5gf
/mnt/slurm/torchtune/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:863: UserWarning: `_get_pg_default_device` will be deprecated, it only stays for backward-compatiblity reason. If you need to find a device for object collectives, please use `_get_object_coll_device`. If you need to query the device types supported by group, please use `_device_capability(group)`.
  warnings.warn(
FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...
Compiling model layers with torch.compile...
/mnt/slurm/torchtune/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:863: UserWarning: `_get_pg_default_device` will be deprecated, it only stays for backward-compatiblity reason. If you need to find a device for object collectives, please use `_get_object_coll_device`. If you need to query the device types supported by group, please use `_device_capability(group)`.
  warnings.warn(
Instantiating model and loading checkpoint took 52.94 secs
Memory stats after model init:
	GPU peak memory allocation: 8.28 GiB
	GPU peak memory reserved: 8.60 GiB
	GPU peak memory active: 8.28 GiB
Optimizer is initialized.
Compiling loss with torch.compile...
Loss is initialized.
Packing dataset: 100%|██████████| 52002/52002 [00:11<00:00, 4612.22it/s]
Dataset and Sampler are initialized.
No learning rate scheduler configured. Using constant learning rate.
 Profiling disabled.
 Profiler config after instantiation: {'enabled': False}
  0%|          | 0/48 [00:00<?, ?it/s]Using flex attention for attention computation since a BlockMask was passed in.
1|48|Loss: 0.9388514161109924: 100%|██████████| 48/48 [07:22<00:00,  8.44s/it]Saving checkpoint. This may take some time. Retrieving full model state dict...
Getting full model state dict took 67.36 secs
Model checkpoint of size 4.27 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00001-of-00030.safetensors
Model checkpoint of size 4.34 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00002-of-00030.safetensors
Model checkpoint of size 4.66 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00003-of-00030.safetensors
Model checkpoint of size 4.63 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00004-of-00030.safetensors
Model checkpoint of size 4.34 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00005-of-00030.safetensors
Model checkpoint of size 4.34 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00006-of-00030.safetensors
Model checkpoint of size 4.34 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00007-of-00030.safetensors
Model checkpoint of size 4.66 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00008-of-00030.safetensors
Model checkpoint of size 4.63 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00009-of-00030.safetensors
Model checkpoint of size 4.34 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00010-of-00030.safetensors
Model checkpoint of size 4.34 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00011-of-00030.safetensors
Model checkpoint of size 4.34 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00012-of-00030.safetensors
Model checkpoint of size 4.66 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00013-of-00030.safetensors
Model checkpoint of size 4.63 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00014-of-00030.safetensors
Model checkpoint of size 4.34 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00015-of-00030.safetensors
Model checkpoint of size 4.34 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00016-of-00030.safetensors
Model checkpoint of size 4.34 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00017-of-00030.safetensors
Model checkpoint of size 4.66 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00018-of-00030.safetensors
Model checkpoint of size 4.63 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00019-of-00030.safetensors
Model checkpoint of size 4.34 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00020-of-00030.safetensors
Model checkpoint of size 4.34 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00021-of-00030.safetensors
Model checkpoint of size 4.34 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00022-of-00030.safetensors
Model checkpoint of size 4.66 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00023-of-00030.safetensors
Model checkpoint of size 4.63 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00024-of-00030.safetensors
Model checkpoint of size 4.34 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00025-of-00030.safetensors
Model checkpoint of size 4.34 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00026-of-00030.safetensors
Model checkpoint of size 4.34 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00027-of-00030.safetensors
Model checkpoint of size 4.66 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00028-of-00030.safetensors
Model checkpoint of size 4.63 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00029-of-00030.safetensors
Model checkpoint of size 1.96 GiB saved to /mnt/slurm/Llama3.3-70B-fft-output/epoch_0/ft-model-00030-of-00030.safetensors
Saving final epoch checkpoint.
The full model checkpoint, including all weights and configurations, has been saved successfully.You can now use this checkpoint for further training or inference.
Saving checkpoint took 623.90 secs
1|48|Loss: 0.9388514161109924: 100%|██████████| 48/48 [17:49<00:00, 22.28s/it]
wandb:
wandb:
wandb: Run history:
wandb:               global_step ▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
wandb:                      loss █▄▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:                        lr ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb:        peak_memory_active ▁███████████████████████████████████████
wandb:         peak_memory_alloc ▁███████████████████████████████████████
wandb:      peak_memory_reserved ▁███████████████████████████████████████
wandb: tokens_per_second_per_gpu ▁███████████████████████████████████████
wandb:
wandb: Run summary:
wandb:               global_step 48
wandb:                      loss 0.93885
wandb:                        lr 2e-05
wandb:        peak_memory_active 44.68959
wandb:         peak_memory_alloc 44.68959
wandb:      peak_memory_reserved 60.72656
wandb: tokens_per_second_per_gpu 890.92334
wandb:
wandb: 🚀 View run llama3.3-70B-v0-compile-1024 at: https://wandb.ai/jcummings/torchtune-multinode/runs/m3m6k5gf
wandb: ⭐️ View project at: https://wandb.ai/jcummings/torchtune-multinode
wandb: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 1 other file(s)
wandb: Find logs at: ./Llama3.3-70B-fft-output/wandb/run-20250129_193034-m3m6k5gf/logs
I0129 19:55:56.785000 8689 torch/distributed/elastic/agent/server/api.py:879] [default] worker group successfully finished. Waiting 300 seconds for other agents to finish.
I0129 19:55:56.787000 8689 torch/distributed/elastic/agent/server/api.py:932] Local worker group finished (WorkerState.SUCCEEDED). Waiting 300 seconds for other agents to finish
I0129 19:56:00.076000 8654 torch/distributed/elastic/agent/server/api.py:879] [default] worker group successfully finished. Waiting 300 seconds for other agents to finish.
I0129 19:56:00.078000 8654 torch/distributed/elastic/agent/server/api.py:932] Local worker group finished (WorkerState.SUCCEEDED). Waiting 300 seconds for other agents to finish
I0129 19:56:00.079000 8654 torch/distributed/elastic/agent/server/api.py:946] Done waiting for other agents. Elapsed: 0.0015294551849365234 seconds
I0129 19:56:00.080000 8689 torch/distributed/elastic/agent/server/api.py:946] Done waiting for other agents. Elapsed: 3.2924535274505615 seconds
Running with torchrun...
Running with torchrun...

Weights & Biases snapshot:
[Screenshot: 2025-01-29 at 3:08:43 PM]

Follow ups

  • Add get_distributed_backend to ALL distributed recipes
  • Test multi-node with distributed LoRA

pytorch-bot bot commented Jan 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2301

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 2 Cancelled Jobs

As of commit 63205da with merge base e6b9064:

NEW FAILURES - The following jobs have failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 27, 2025
@joecummings joecummings changed the title from "Remove last references to … from training" to "Multinode support in torchtune" Jan 27, 2025
if device_type in dist.Backend.default_device_backend_map.keys():
    backend = dist.Backend.default_device_backend_map.get(device_type)
if enable_cpu_offload:
    backend = f"{device_type}:{backend},cpu:gloo"
Contributor

I think gloo backend will also be necessary for async save

Contributor Author

Cite your sources

Contributor Author

I found the sources.

Contributor

[Screenshot: 2025-01-30 at 2:33:28 PM]


**Low inter-node bandwidth & FSDP**
We utilize <FSDP> to distribute models over multiple devices. In order to distribute training, FSDP runs an all-gather operation for each forward pass and an all-gather plus a scatter-reduce
operation for each backwards pass. These operations (usually) block training from continuing until completed and with a slow inter-node connection, training speed may be reduced.
Contributor

why would we have all-gather in backward? I thought it's all gather in forward and reduce-scatter in backward?

Contributor Author

That's a good question! The default for FSDP is to reshard after forward in order to save memory. If resharded, they need to be all-gathered before the backwards pass, too. If not, then you are correct, there's no reason to all-gather again.
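
A minimal sketch of the knob in question (assumes FSDP2's fully_shard API and an already-initialized process group, e.g. via torchrun; illustrative only, not code from this PR):

import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))
for layer in model:
    # reshard_after_forward=True (the default) frees the unsharded parameters after
    # forward, saving memory but requiring another all-gather at the start of backward.
    # Setting it to False keeps the gathered parameters around and skips that collective.
    fully_shard(layer, reshard_after_forward=True)
fully_shard(model, reshard_after_forward=True)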

Contributor

gotcha, maybe explain that in the readme?


What else do you want?

BLAH BLHAH BALSHD 很好
Contributor

😅

# tune download meta-llama/Llama-3.3-70B-Instruct --ignore-patterns "original/consolidated*" --output-dir SHARED_CLUSTER_FS
#
# To launch on 2 nodes w/ 8 devices on a SLURM cluster, run the following command:
# sbatch full_finetune_multinode.slurm
Contributor

@acisseJZhong acisseJZhong Jan 28, 2025

is there a way to make full_finetune_multinode.slurm take in an argument to specify which config/model to run, instead of creating a new config for multinode?

Contributor Author

Arghh, this would be a good idea. I'm leaning towards just trying to get this up there as an example of how to run since you'll really need to modify the SLURM file itself in order to set the correct number of nodes, etc.

Open to thoughts though. cc @ebsmothers @pbontrager

Contributor

Yeah I agree with @acisseJZhong's suggestion. Also I think the concept of recipes + configs breaks down a bit here. I think we should either very explicitly say "this is just a demo and is not a real recipe" (i.e. we don't even list it in recipes), or we should properly integrate with tune run -- i.e. if one specifies tune run --nnodes {>1} ... we dispatch to a generic slurm script on the backend (this is just one UX.. could also require explicit --slurm arg or something like that)

Contributor Author

Bleh

Contributor Author

Make them copy it?

Contributor Author

Okay will make them copy it for now and not add to recipe registry, but I will keep the script there.

if self.fsdp_cpu_offload:
    # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
    # speed up when benchmarking fused AdamW on CPU
    training.set_torch_num_threads()
Contributor

do we always want to set this?

Contributor Author

That's a good point. Looks like this was added by Rohan, so not sure who to follow up with here. Let me dig into it.

Contributor

Yeah this is a heuristic for fused Adam on CPU when CPU offload is enabled. I don't think it's optimal, but I do think that without it CPU offload training may be much slower
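
For reference, the heuristic boils down to something like this (a hedged sketch; the real training.set_torch_num_threads may divide by a different quantity):

import os

import torch

def set_torch_num_threads() -> None:
    # Split the host's cores evenly across local ranks (torchrun exports
    # LOCAL_WORLD_SIZE) so the fused CPU AdamW step gets intra-op parallelism
    # without oversubscribing the CPU.
    local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
    torch.set_num_threads(max(1, (os.cpu_count() or 1) // local_world_size))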

Contributor Author

Should it be set for async offload too? Or pure CPU training?

Contributor

Afaik it shouldn't matter for async offload and mostly has to do with fused optimizer. For pure CPU training I guess the optimizer step also happens on CPU so in that case we would potentially want it

@@ -240,9 +245,16 @@ def setup(self, cfg: DictConfig) -> None:
        Setup the recipe. This includes training state (if resume_from_checkpoint is True),
        model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader.
        """
        # Set up the backend for distributed training (NCCL, GLOO, etc.)
        init_process_group(self.distributed_backend)
Contributor

curious why do we want to move this block from recipe_main to setup?

Contributor Author

In my mind, this is doing actual setup. Therefore it should belong with the rest of the setup code, not buried at the bottom of the recipe where it's hard to find.
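
Pieced together from the hunks quoted in this thread, the setup-time flow is roughly the following (a sketch; the helper name init_distributed is illustrative, not the recipe's actual structure):

from torch.distributed import init_process_group

from torchtune import training

def init_distributed(device_type: str, fsdp_cpu_offload: bool, enable_async_checkpointing: bool) -> str:
    # Pick NCCL/gloo/etc. for the device, adding gloo when any collectives need to
    # run on CPU, then start the process group with that backend string.
    backend = training.get_distributed_backend(
        device_type,
        offload_ops_to_cpu=fsdp_cpu_offload or enable_async_checkpointing,
    )
    init_process_group(backend)
    return backend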

@@ -240,9 +245,16 @@ def setup(self, cfg: DictConfig) -> None:
        Setup the recipe. This includes training state (if resume_from_checkpoint is True),
        model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader.
        """
        # Set up the backend for distributed training (NCCL, GLOO, etc.)
        init_process_group(self.distributed_backend)
Contributor

can you also update generate_v2_distributed recipe?

Contributor Author

Follow-up for all distributed recipes :)

Joe Cummings added 2 commits January 29, 2025 11:21
@joecummings joecummings marked this pull request as ready for review January 29, 2025 16:27
@SalmanMohammadi SalmanMohammadi mentioned this pull request Jan 30, 2025
Contributor

@ebsmothers ebsmothers left a comment

Great work and 10/10 tutorial

More machines means more memory! This is cool for several reasons:

1. **Bigger models**: With more memory, you can train larger models such as `Llama3.1 405B <https://ai.meta.com/blog/meta-llama-3-1/>`_, `Deepseek-V3 <https://www.deepseek.com/>`_, and more.
2. **Longer data**: More many tasks like writing code, it's helpful to have long context lengths; however longer context length means more memory needed for activations.
Contributor

Separately would be a little bit careful about how we frame this. Like we don't actually have context parallel yet so don't wanna imply that people can continually scale context length with # of nodes.

Suggested change
2. **Longer data**: More many tasks like writing code, it's helpful to have long context lengths; however longer context length means more memory needed for activations.
2. **Longer data**: For many tasks like writing code, it's helpful to have long context lengths; however longer context length means more memory needed for activations.

Multi-node finetuning
=====================

Congratulations! After years of being "GPU poor", you've worked hard, saved your hard earned Bitcoin and graduated to the
Contributor

sorry but discussions of crypto are banned on our docs


.. note::

**Low inter-node bandwidth & FSDP** We utilize `Fully Sharded Data Parallel <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`_ to distribute models over multiple devices. In order to distribute training, FSDP runs an `all-gather <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather>`_ operation
Contributor

Personally I would not point to this FSDP blog post as pretty much all the APIs given there are moot for torchtune's purposes

Contributor Author

Fair.

.. note::

**Low inter-node bandwidth & FSDP** We utilize `Fully Sharded Data Parallel <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`_ to distribute models over multiple devices. In order to distribute training, FSDP runs an `all-gather <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather>`_ operation
for each forward pass and an all-gather plus a `scatter-reduce <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter>`_ operation for each backwards pass. These operations (usually) block training from continuing until completed and with a slow
Contributor

Suggested change
for each forward pass and an all-gather plus a `scatter-reduce <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter>`_ operation for each backwards pass. These operations (usually) block training from continuing until completed and with a slow
for each forward pass and an all-gather plus a `reduce-scatter <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter>`_ operation for each backwards pass. These operations (usually) block training from continuing until completed and with a slow

Contributor Author

I've heard it both ways.


Now that we have a downloaded model, let's check out our example SLURM bash script.

.. literalinclude:: ../../../recipes/full_finetune_multinode.slurm
Contributor

Cool, didn't know you could do this. But one nit is that it includes the license, which looks a little weird in the docs imo

Contributor Author

I can take it out of the recipes and just have people copy and paste from the tutorial? Less findable from Github tho.

Comment on lines +110 to +113
Config(
    name="llama3_3/70B_full_multinode",
    file_path="llama3_3/70B_full_multinode.yaml",
),
Contributor

K so are we keeping this in? I get we wanna show in tune ls but also it won't actually work with just tune run, right?

Contributor Author

So the only difference with this one is that I turn off a bunch of the memory optimizations (b/c we don't need them with multi-node!). I'm happy to rename to something like _fast, but _multinode really explains what it's for.

Comment on lines +32 to +36
# You probably want to load in a virtual env w/ conda...
# module load conda
# conda activate torchtune
# ...or venv
# source torchtune/bin/activate
Contributor

Why is this commented out? Is it because we don't know the user's venv/conda env? I remember wasting a bunch of time myself on this kinda stuff before, might be worth explicitly calling it out in the tutorial

Contributor Author

Yeah, we can't make any assumptions about how they initialize their virtual env

@@ -11,6 +11,7 @@
from torchtune.training._compile import compile_loss, compile_model
from torchtune.training._distributed import (
    gather_cpu_state_dict,
    get_distributed_backend,
Contributor

Are we fully deprecating get_world_size_and_rank in this PR? Seems like the API still exists and is imported here too

Contributor Author

The API just moved to utils from training.

self.distributed_backend = training.get_distributed_backend(
    device_type,
    offload_ops_to_cpu=self.fsdp_cpu_offload
    or self._enable_async_checkpointing,
Contributor

Did you run with async checkpointing? Pretty interested to know how much time it saves on multiple nodes

Contributor Author

Nope.

"cpu": "gloo",
"xpu": "xccl",
}
# TODO: Uncomment the following line once PyTorch 2.6 is released
Contributor

it's released

Contributor Author

Okay, then big question: Do we force people to upgrade immediately to PyTorch 2.6? Otherwise, this function will not work.

Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.