Add FSDP support to TiledMLP by preventing premature resharding during the tiled backward recompute loop. #1128
base: main
@@ -10,6 +10,9 @@
from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP
from liger_kernel.utils import infer_device

import tempfile
import torch.multiprocessing as mp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

device = infer_device()
@@ -195,3 +198,90 @@ def test_tiled_swiglu_correctness(
    )

    torch.testing.assert_close(x1.grad, x2.grad, atol=atol, rtol=rtol, msg="Input gradients don't match")


def _test_fsdp_tiled_mlp(rank, world_size, bs, hidden_size, intermediate_size, num_shards, dtype, atol, rtol, file_name):
    # Init process group
    torch.distributed.init_process_group(
        backend="nccl",
        init_method=f"file://{file_name}",
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)
    device = f"cuda:{rank}"

    config = LlamaConfig(
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        hidden_act="silu",
    )

    # Seed for replication
    torch.manual_seed(42)
    G = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype)
    U = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype)
    D = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype)

    model = LigerTiledSwiGLUMLP(config=config, num_shards=num_shards).to(device).to(dtype)
    model.gate_proj.weight.data = G.clone()
    model.up_proj.weight.data = U.clone()
    model.down_proj.weight.data = D.clone()

    # Wrap with FSDP
    model = FSDP(model, use_orig_params=True)

    # Reference: same weights, no FSDP
    ref_model = LigerTiledSwiGLUMLP(config=config, num_shards=num_shards).to(device).to(dtype)
    # Copy weights from FSDP model (need to gather first or init identically)
    ref_model.gate_proj.weight.data = G.clone()
    ref_model.up_proj.weight.data = U.clone()
    ref_model.down_proj.weight.data = D.clone()

    # Forward + backward
    x = torch.randn(bs, hidden_size, device=device, dtype=dtype).requires_grad_(True)

    out = model(x)
    out.sum().backward()

    ref_out = ref_model(x)
    ref_out.sum().backward()

    # Assert
    torch.testing.assert_close(out, ref_out, atol=atol, rtol=rtol)
Collaborator: missing gradients comparison
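A possible shape for that gradient check (a sketch, not part of this PR; the separate input leaves and the summon_full_params call are assumptions, not the author's code):

    # Give each model its own leaf input so the two backward passes
    # don't accumulate into the same .grad
    x_fsdp = x.detach().clone().requires_grad_(True)
    x_ref = x.detach().clone().requires_grad_(True)

    out = model(x_fsdp)
    out.sum().backward()
    ref_out = ref_model(x_ref)
    ref_out.sum().backward()

    torch.testing.assert_close(out, ref_out, atol=atol, rtol=rtol)
    torch.testing.assert_close(x_fsdp.grad, x_ref.grad, atol=atol, rtol=rtol)

    # Parameter gradients are sharded under FSDP, so unshard them before comparing.
    # with_grads=True requires use_orig_params=True, which the test already passes.
    # Both ranks see the same seed and input, so FSDP's gradient averaging is a no-op here.
    with FSDP.summon_full_params(model, with_grads=True):
        # relies on both modules yielding parameters in the same order (gate, up, down)
        for (_, p_fsdp), (_, p_ref) in zip(model.named_parameters(), ref_model.named_parameters()):
            torch.testing.assert_close(p_fsdp.grad, p_ref.grad, atol=atol, rtol=rtol)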
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="requires at least 2 GPUs")
@pytest.mark.parametrize("world_size", [2])  # extend to [2, 4, 8] on multi-GPU hosts
Collaborator: Extending to [2, 4, 8] is fine. Just make sure they are skipped if the GPUs are not present.

Author: Okay, done.
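One way to express that per-world-size skip (a sketch, not the committed change):

    @pytest.mark.parametrize(
        "world_size",
        [
            pytest.param(
                ws,
                marks=pytest.mark.skipif(torch.cuda.device_count() < ws, reason=f"requires at least {ws} GPUs"),
            )
            for ws in (2, 4, 8)
        ],
    )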
@pytest.mark.parametrize("num_shards", [1, 2, 4])
@pytest.mark.parametrize(
    "bs, hidden_size, intermediate_size",
    [
        (2, 256, 512),
        (2, 512, 1024),
        (1, 128, 256),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.float32, 1e-5, 1e-5),
        pytest.param(
            torch.bfloat16,
            1e-1,
            1e-1,
            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported"),
        ),
    ],
)
def test_fsdp_tiled_swiglu(world_size, num_shards, bs, hidden_size, intermediate_size, dtype, atol, rtol):
    with tempfile.NamedTemporaryFile() as f:
        mp.spawn(
            _test_fsdp_tiled_mlp,
            args=(world_size, bs, hidden_size, intermediate_size, num_shards, dtype, atol, rtol, f.name),
            nprocs=world_size,
            join=True,
        )
Collaborator: test against the torch implementation as reference

Author: Could you clarify what "the torch implementation as reference" means here? FSDP itself already comes from torch. Do you mean using a plain torch MLP, wrapping it with FSDP, and checking it works as expected (or plain MLP vs. TiledMLP + FSDP)? Or comparing FSDP1 with FSDP2? At the moment, I just compared the TiledMLP wrapped with FSDP against the LigerSwiGLUMLP wrapped with FSDP.
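If "torch implementation" means a plain eager-mode SwiGLU MLP built from nn.Linear (one possible reading; this module is hypothetical and not part of the PR), the reference could look like:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TorchSwiGLUMLP(nn.Module):
        # Hypothetical reference module: SwiGLU MLP in plain PyTorch, matching the
        # gate/up/down projection layout used by LigerTiledSwiGLUMLP.
        def __init__(self, hidden_size, intermediate_size):
            super().__init__()
            self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
            self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
            self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

        def forward(self, x):
            # SwiGLU: down(silu(gate(x)) * up(x))
            return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

Seeding its gate_proj/up_proj/down_proj weights with the same G/U/D tensors as above would let the FSDP-wrapped tiled MLP be checked against a reference that involves neither Liger kernels nor FSDP.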