Hello,
I am trying to train a 120B MoE model using both FSDP and FSDP2. Training works with both backends, but each required some specific changes.
Config details:
- I am using GB200 nodes: 16 nodes for the vLLM engines and 16 nodes for colocating the reference and policy models with FSDP. For FSDP2, I had to keep the reference and policy models on separate nodes to make it work (i.e., 48 nodes total: 16 for the policy model, 16 for the reference model, and 16 for the vLLM engines).
- I am using a global batch size of 1024 (64 unique prompts × 16 samples per prompt); micro batch sizes of 1, 4, 8, and 16 all work fine.
- Using sequence parallelism (SP) of 4 for the trainer and tensor parallelism (TP) of 4 for the vLLM engines.
- The model has input and output embedding weights tied.
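
For reference, the weight tying can be confirmed from the model config; a minimal sketch using the Hugging Face API (the model path is a placeholder, not the actual checkpoint name):

```python
from transformers import AutoConfig

# Placeholder path for the internal 120B MoE checkpoint.
cfg = AutoConfig.from_pretrained("path/to/120b-moe")
print(cfg.tie_word_embeddings)  # True -> lm_head shares the input embedding weights
```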
FSDP changes:
- The model checkpoint is stored in fp32 precision; loading the reference model in bf16 always led to a "ray actor died" error.
- Upon loading the model in fp32, the policy and reference models load fine, even when they are colocated.
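
As a rough sketch of the workaround (hedged: `AutoModelForCausalLM.from_pretrained` stands in for however the worker actually constructs the reference model, and the path is a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM

# Keep the reference model in the checkpoint's native fp32 instead of
# casting to bf16 at load time; the bf16 path is what killed the Ray actor.
ref_model = AutoModelForCausalLM.from_pretrained(
    "path/to/120b-moe",          # placeholder path
    torch_dtype=torch.float32,   # bf16 here reproduced the "ray actor died" error
)
```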
FSDP2 changes:
- With the default settings, loading the trainer models failed in all of the following cases: policy and reference models on separate nodes, colocated mode, and loading only the policy model.
- Since our model uses weight tying, `get_init_weight_context_manager` in the `fsdp_worker.py` script loaded the models on each rank's CPU instead of only on rank 0 (see the sketch after this list for the pattern as I understand it).
- Upon bypassing the tied-embeddings check and passing `use_meta_tensor=True`, the models loaded fine (although the policy and reference models had to be loaded on separate nodes). However, training then produced NaN loss.
- To fix the NaN issue, I had to manually tie the weights before sharding in `fsdp_strategy.py`:
```python
# Unwrap to the inner HF model if the strategy has already wrapped it
module = model.model if is_wrapped else model
# Re-tie the input/output embeddings so both names reference the same tensor
module.tie_weights()
# Snapshot the full (unsharded) state dict before FSDP2 shards the module
full_state = module.state_dict()
apply_fsdp2(module, fsdp_kwargs, self.fsdp_config)
# Load the full state dict back into the now-sharded module
fsdp2_load_full_state_dict(module, full_state, cpu_offload)
fsdp_module = module
```
- After loading the weights on CPU only on rank 0, tying the weights before sharding, and keeping the policy and reference models on separate nodes, FSDP2 works fine with this model.
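
For reference, here is my understanding of the meta-tensor init pattern involved, as a minimal hedged sketch (my reading, not the actual `get_init_weight_context_manager` implementation): with `use_meta_tensor=True`, only rank 0 materializes real CPU weights while the other ranks build shape-only parameters on the meta device, which is exactly what the tied-embeddings check disables:

```python
from contextlib import nullcontext
import torch

# Hedged sketch: which weight-init context each rank should get.
def init_weight_ctx(use_meta_tensor: bool, rank: int):
    if use_meta_tensor and rank != 0:
        # Non-zero ranks: parameters live on the meta device
        # (shape/dtype only, no storage), saving host RAM.
        return torch.device("meta")
    # Rank 0, or the tied-embeddings fallback on every rank: real CPU tensors.
    return nullcontext()

# Usage sketch:
#   with init_weight_ctx(use_meta_tensor=True, rank=dist.get_rank()):
#       model = AutoModelForCausalLM.from_config(cfg)
```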
Please let me know if the above config and changes for FSDP2 seem reasonable, or if there is a more efficient approach. Thank you!