Is there a way to offload training memory to DRAM (using FSDP2?) for training Llama3-8B with torchtitan? #620
Description
I am training Llama3-8B on 2 RTX A6000ada 48GB GPUs, but I get an out-of-memory (OOM) error. Is there a way to offload training memory to DRAM (e.g., using FSDP2) when training Llama3-8B with torchtitan?
Thanks!
**Error message:**
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 112.00 MiB. GPU 0 has a total capacity of 47.48 GiB of which 92.81 MiB is free. Including non-PyTorch memory, this process has 46.71 GiB memory in use. Of the allocated memory 45.56 GiB is allocated by PyTorch, and 448.27 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
**Here is my training config:**
# torchtitan Config.toml
# NOTE: this toml config is a preset for 64 A100 GPUs.
# Job-level settings: an output folder path and a free-text description of the run.
[job]
dump_folder = "./outputs"
description = "Llama 3 8B training"
# Profiling settings: profiling is switched on, with a trace folder name and a
# frequency of 100.
# NOTE(review): profile_freq is presumably counted in training steps, and the
# folder is presumably resolved relative to job.dump_folder — confirm.
[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 100
# Metrics settings: a log frequency of 10 and TensorBoard output enabled, with
# "tb" as the TensorBoard folder name.
# NOTE(review): log_freq is presumably counted in training steps — confirm.
[metrics]
log_freq = 10
enable_tensorboard = true
save_tb_folder = "tb"
# Model selection: the llama3 model at the 8B flavor, plus the normalization
# type and the path to the tokenizer model file.
[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
# Optimizer settings: AdamW with a learning rate of 3e-4.
# NOTE(review): AdamW keeps extra per-parameter optimizer state, which
# presumably contributes to the reported memory pressure — confirm with a
# memory snapshot.
[optimizer]
name = "AdamW"
lr = 3e-4
# Training-loop settings: batch shape, LR warmup, gradient clipping, step count,
# parallelism degrees, compilation and dataset.
# The trailing `#...` values on batch_size, seq_len and the parallelism degrees
# look like previously tried or preset values kept for reference — TODO confirm.
# NOTE(review): the report says this runs on 2 GPUs; with
# tensor_parallel_degree = 2, a data_parallel_shard_degree of -1 presumably
# resolves to 1 (i.e. no sharding across GPUs) — confirm.
[training]
batch_size = 2 #1
seq_len = 256 #512 #8192
warmup_steps = 200 # lr scheduler warm up
max_norm = 1.0 # grad norm clipping
steps = 1000
data_parallel_replicate_degree = 1 #1
data_parallel_shard_degree = -1 #-1
tensor_parallel_degree = 2 #1
compile = true
dataset = "c4"
# Experimental features: pipeline parallel degree set to 1 (presumably meaning
# no pipeline parallelism — confirm) and async tensor parallelism enabled.
[experimental]
pipeline_parallel_degree = 1 #1
enable_async_tensor_parallel = true
# Checkpoint settings. enable_checkpoint is false here, so the remaining keys
# presumably have no effect for this run — confirm.
[checkpoint]
enable_checkpoint = false #false
folder = "checkpoint"
interval_type = "steps"
interval = 500
model_weights_only = false
export_dtype = "bfloat16" #32
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
# Activation checkpointing: selective mode with the op-based option; the
# trailing comments list the accepted values for each key.
[activation_checkpoint]
mode = 'selective' # ['none', 'selective', 'full']
selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
# Float8 settings: float8 linear layers, float8 FSDP all-gather, and
# precomputed float8 dynamic scales for FSDP are all enabled.
# NOTE(review): the two FSDP-related options presumably only matter when FSDP
# sharding is actually active for this run — confirm.
[float8]
enable_float8_linear = true
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true