src/liger_kernel/ops/tiled_mlp.py (25 changes: 22 additions & 3 deletions)
@@ -76,6 +76,9 @@ def backward(ctx, *grads) -> tuple:
         incoming_grad = grads[0].view(-1, hidden_size)
         x_grad = torch.zeros_like(x)
 
+        # initialize param grad accumulators
+        param_grads = {p: None for p in mlp_module.parameters()}
+
         x_shards = list(torch.chunk(x, chunks=shards, dim=0))
 
         for i, x_shard in enumerate(x_shards):
@@ -84,13 +87,29 @@ def backward(ctx, *grads) -> tuple:
             # if seqlen is not exactly divisible by shards the last step will be shorter than shard_step
             shard_step = x_shards[i].shape[0]
             shard_offset = i * x_shards[0].shape[0]
 
-            x_shard.grad = x_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
             incoming_grad_shard = incoming_grad.narrow(0, shard_offset, shard_step).view_as(x_shard)
 
             with torch.enable_grad():
                 output = fn(mlp_module, x_shard)
-            torch.autograd.backward(output, incoming_grad_shard)
+            local_grads = torch.autograd.grad(
+                outputs=output,
+                inputs=[x_shard] + list(mlp_module.parameters()),
+                grad_outputs=incoming_grad_shard,
+            )
+
+            x_grad.narrow(0, shard_offset, shard_step).copy_(local_grads[0])
+
+            for p, g in zip(mlp_module.parameters(), local_grads[1:]):
+                if param_grads[p] is None:
+                    param_grads[p] = g
+                else:
+                    param_grads[p] += g
+
+        for p, g in param_grads.items():
+            if p.grad is None:
+                p.grad = g
+            else:
+                p.grad += g
 
         # unflatten
         x_grad = x_grad.view(x_shape_orig)
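The change above replaces the per-shard torch.autograd.backward call with torch.autograd.grad plus manual accumulation of parameter gradients. A minimal, self-contained sketch of that pattern, assuming a toy nn.Linear in place of the real MLP module (hypothetical illustration, not library code):

# Hypothetical sketch: per-shard recompute with torch.autograd.grad, accumulating
# parameter gradients across shards and checking against a single full backward.
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 4, bias=False)   # stands in for the MLP module
x = torch.randn(8, 4)
incoming_grad = torch.ones(8, 4)      # stands in for the upstream gradient

# reference: one full backward
x_ref = x.clone().requires_grad_(True)
layer(x_ref).backward(incoming_grad)
ref_weight_grad = layer.weight.grad.clone()
layer.weight.grad = None

# sharded: recompute each shard under enable_grad, accumulate gradients manually
param_grads = {p: None for p in layer.parameters()}
shard_x_grads = []
for x_shard, g_shard in zip(torch.chunk(x, 2, dim=0), torch.chunk(incoming_grad, 2, dim=0)):
    x_shard = x_shard.detach().requires_grad_(True)
    with torch.enable_grad():
        out = layer(x_shard)
    grads = torch.autograd.grad(out, [x_shard] + list(layer.parameters()), grad_outputs=g_shard)
    shard_x_grads.append(grads[0])
    for p, g in zip(layer.parameters(), grads[1:]):
        param_grads[p] = g if param_grads[p] is None else param_grads[p] + g

torch.testing.assert_close(torch.cat(shard_x_grads, dim=0), x_ref.grad)
torch.testing.assert_close(param_grads[layer.weight], ref_weight_grad)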
test/transformers/test_tiled_mlp.py (90 changes: 90 additions & 0 deletions)
@@ -10,6 +10,9 @@
from liger_kernel.transformers.tiled_mlp import LigerTiledSwiGLUMLP
from liger_kernel.utils import infer_device

import tempfile
import torch.multiprocessing as mp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
device = infer_device()


@@ -195,3 +198,90 @@ def test_tiled_swiglu_correctness(
    )

    torch.testing.assert_close(x1.grad, x2.grad, atol=atol, rtol=rtol, msg="Input gradients don't match")


def _test_fsdp_tiled_mlp(rank, world_size, bs, hidden_size, intermediate_size, num_shards, dtype, atol, rtol, file_name):
    # Init process group
    torch.distributed.init_process_group(
        backend="nccl",
        init_method=f"file://{file_name}",
        rank=rank,
        world_size=world_size,
    )
    torch.cuda.set_device(rank)
    device = f"cuda:{rank}"

    config = LlamaConfig(
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        hidden_act="silu",
    )

    # Seed for replication
    torch.manual_seed(42)
    G = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype)
    U = torch.randn(intermediate_size, hidden_size, device=device, dtype=dtype)
    D = torch.randn(hidden_size, intermediate_size, device=device, dtype=dtype)

    model = LigerTiledSwiGLUMLP(config=config, num_shards=num_shards).to(device).to(dtype)
    model.gate_proj.weight.data = G.clone()
    model.up_proj.weight.data = U.clone()
    model.down_proj.weight.data = D.clone()

    # Wrap with FSDP
    model = FSDP(model, use_orig_params=True)

    # Reference: same weights, no FSDP
    ref_model = LigerTiledSwiGLUMLP(config=config, num_shards=num_shards).to(device).to(dtype)
Collaborator: test against torch implementation as reference

Author: Could you clarify what you mean by the torch implementation as reference? FSDP is already from torch. Do you mean using an MLP from torch, wrapping it with FSDP, and checking it works as expected (or a plain MLP vs. TiledMLP + FSDP)? Or comparing FSDP1 with FSDP2? At the moment I just compare the TiledMLP wrapped with FSDP against the LigerSwiGLUMLP wrapped with FSDP.
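One possible reading of the reviewer's suggestion, as a hypothetical sketch and not part of this PR: build the reference output from the same G/U/D weights with plain torch ops, so the FSDP-wrapped tiled module is checked against a torch-only SwiGLU rather than another Liger module. The snippet reuses the test's x, out, G, U, D, atol, rtol names and assumes the usual gate/up/down SwiGLU layout:

# Hypothetical torch-only SwiGLU reference built from the same weights.
import torch.nn.functional as F

def swiglu_reference(x, G, U, D):
    # down_proj(silu(gate_proj(x)) * up_proj(x)), expressed as plain matmuls
    return (F.silu(x @ G.T) * (x @ U.T)) @ D.T

ref_out = swiglu_reference(x, G, U, D)
torch.testing.assert_close(out, ref_out, atol=atol, rtol=rtol)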

    # Copy weights from FSDP model (need to gather first or init identically)
    ref_model.gate_proj.weight.data = G.clone()
    ref_model.up_proj.weight.data = U.clone()
    ref_model.down_proj.weight.data = D.clone()

    # Forward + backward
    x = torch.randn(bs, hidden_size, device=device, dtype=dtype).requires_grad_(True)

    out = model(x)
    out.sum().backward()

    ref_out = ref_model(x)
    ref_out.sum().backward()

    # Assert
    torch.testing.assert_close(out, ref_out, atol=atol, rtol=rtol)

Collaborator: missing gradients comparison
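A hypothetical sketch of the gradient comparison being requested, assuming each model gets its own leaf input so the input gradients do not accumulate into a shared tensor; parameter gradients would additionally need gathering before comparison, e.g. via FSDP.summon_full_params(model, with_grads=True):

# Hypothetical addition inside _test_fsdp_tiled_mlp: compare input gradients of the
# FSDP-wrapped tiled model against the reference model.
x1 = torch.randn(bs, hidden_size, device=device, dtype=dtype, requires_grad=True)
x2 = x1.detach().clone().requires_grad_(True)

model(x1).sum().backward()
ref_model(x2).sum().backward()

torch.testing.assert_close(x1.grad, x2.grad, atol=atol, rtol=rtol, msg="Input gradients don't match")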


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="requires at least 2 GPUs")
@pytest.mark.parametrize("world_size", [2])  # extend to [2, 4, 8] on multi-GPU hosts
Collaborator: Extending to [2, 4, 8] is fine. Just make sure they are skipped if GPUs are not present.

Author: Okay, done.


@pytest.mark.parametrize("num_shards", [1, 2, 4])
@pytest.mark.parametrize(
"bs, hidden_size, intermediate_size",
[
(2, 256, 512),
(2, 512, 1024),
(1, 128, 256)
],
)
@pytest.mark.parametrize(
"dtype, atol, rtol",
[
(torch.float32, 1e-5, 1e-5),
pytest.param(
torch.bfloat16,
1e-1,
1e-1,
marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported"),
),
],
)
def test_fsdp_tiled_swiglu(world_size, num_shards, bs, hidden_size, intermediate_size, dtype, atol, rtol):
with tempfile.NamedTemporaryFile() as f:
mp.spawn(
_test_fsdp_tiled_mlp,
args=(world_size, bs, hidden_size, intermediate_size, num_shards, dtype, atol, rtol, f.name),
nprocs=world_size,
join=True,
)
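Relating to the world_size exchange in the review above, a hypothetical sketch of one way to extend the parametrization to [2, 4, 8] while skipping sizes that exceed the available GPUs; the WORLD_SIZES name and the test stub below are illustrative, not part of the PR:

# Hypothetical sketch: give each world size its own skipif mark so hosts with fewer
# GPUs skip the larger configurations instead of failing.
WORLD_SIZES = [
    pytest.param(
        ws,
        marks=pytest.mark.skipif(torch.cuda.device_count() < ws, reason=f"requires at least {ws} GPUs"),
    )
    for ws in (2, 4, 8)
]

@pytest.mark.parametrize("world_size", WORLD_SIZES)
def test_fsdp_tiled_swiglu_multi(world_size):  # stand-in for the real test signature
    ...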