-
Notifications
You must be signed in to change notification settings - Fork 651
Description
Bug description
Repro
import torch
import torch.nn as nn
import time
import os
import torch.distributed as dist
from simple_fsdp import data_parallel
class SimpleMLP(nn.Module):
    """Minimal module whose forward reads the same linear weight twice.

    The layer is an ``nn.Linear`` used only as a parameter container: the
    forward pass accesses ``self.net.weight`` and ``self.net.bias`` directly
    instead of calling ``self.net(x)``.
    """

    def __init__(self, input_dim, output_dim):
        """Create the single ``nn.Linear(input_dim, output_dim)`` layer."""
        super().__init__()
        self.net = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return ``x @ W.T + b + x @ W.T`` using the layer's weight and bias.

        ``self.net.weight`` is intentionally accessed twice; that repeated
        access is the weight reuse this repro is built around.
        """
        return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
input_dim, output_dim = 128, 10
def main():
    """Run one compiled forward pass of a Simple-FSDP-wrapped ``SimpleMLP``.

    Reads ``RANK`` and ``WORLD_SIZE`` (required) and ``LOCAL_RANK``
    (optional) from the environment, initializes an NCCL process group,
    builds a 1-D ``"dp"`` device mesh, wraps the model with
    ``data_parallel(..., mode="fully_shard")``, compiles it and prints the
    output shape.

    Raises:
        KeyError: if ``RANK`` or ``WORLD_SIZE`` is missing from the environment.
    """
    rank = int(os.environ["RANK"])
    # LOCAL_RANK falls back to the global rank when it is not set.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    world_size = int(os.environ["WORLD_SIZE"])

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl", rank=rank, world_size=world_size, device_id=device)

    # The mesh is hard-coded to two ranks on a single "dp" dimension.
    # NOTE(review): presumably this must match WORLD_SIZE, i.e. the repro is
    # meant to be launched with exactly 2 processes -- confirm.
    dp_mesh = dist.device_mesh.init_device_mesh(
        device.type,
        mesh_shape=(2,),
        mesh_dim_names=("dp",),
    )

    module = SimpleMLP(input_dim, output_dim).to(device)
    module = data_parallel(module, dp_mesh, mode="fully_shard")
    module = torch.compile(module, mode="reduce-overhead")

    x = torch.randn(32, input_dim, device=device)
    output = module(x)
    print(f"Output shape: {output.shape}")


if __name__ == "__main__":
    main()
Problem
Currently, reusing weights in a module causes Simple FSDP to trigger multiple All-Gather operations for the same parameter (`self.net.weight` in the above example), duplicating AG/RS nodes in the FX graph and hurting performance.
Dynamo graph: duplicated torch__dynamo_variables_tensor_prim_redistribute calls
Dynamo graph
class GraphModule(torch.nn.Module):
def forward(self, L_self_modules_net_parameters_weight_: "f32[10, 128][128, 1]cuda:0", L_x_: "f32[32, 128][128, 1]cuda:0", L_self_modules_net_parameters_bias_: "f32[10][1]cuda:0"):
l_self_modules_net_parameters_weight_ = L_self_modules_net_parameters_weight_
l_x_ = L_x_
l_self_modules_net_parameters_bias_ = L_self_modules_net_parameters_bias_
# File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:250 in replicate_compute, code: output = x.redistribute(
output: "f32[10, 128][128, 1]cuda:0" = torch__dynamo_variables_tensor_prim_redistribute(l_self_modules_net_parameters_weight_)
# File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:257 in replicate_compute, code: output = output.to_local(grad_placements=self.grad_placements)
output_1: "f32[10, 128][128, 1]cuda:0" = torch__dynamo_variables_tensor_prim_to_local(output); output = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
getattr_1: "f32[128, 10][1, 128]cuda:0" = output_1.T; output_1 = None
matmul: "f32[32, 10][10, 1]cuda:0" = torch.matmul(l_x_, getattr_1); getattr_1 = None
# File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:250 in replicate_compute, code: output = x.redistribute(
output_2: "f32[10][1]cuda:0" = torch__dynamo_variables_tensor_prim_redistribute_1(l_self_modules_net_parameters_bias_); l_self_modules_net_parameters_bias_ = None
# File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:257 in replicate_compute, code: output = output.to_local(grad_placements=self.grad_placements)
output_3: "f32[10][1]cuda:0" = torch__dynamo_variables_tensor_prim_to_local_1(output_2); output_2 = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
add: "f32[32, 10][10, 1]cuda:0" = matmul + output_3; matmul = output_3 = None
# File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:250 in replicate_compute, code: output = x.redistribute(
output_4: "f32[10, 128][128, 1]cuda:0" = torch__dynamo_variables_tensor_prim_redistribute_2(l_self_modules_net_parameters_weight_); l_self_modules_net_parameters_weight_ = None
# File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:257 in replicate_compute, code: output = output.to_local(grad_placements=self.grad_placements)
output_5: "f32[10, 128][128, 1]cuda:0" = torch__dynamo_variables_tensor_prim_to_local_2(output_4); output_4 = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
getattr_2: "f32[128, 10][1, 128]cuda:0" = output_5.T; output_5 = None
matmul_1: "f32[32, 10][10, 1]cuda:0" = torch.matmul(l_x_, getattr_2); l_x_ = getattr_2 = None
add_1: "f32[32, 10][10, 1]cuda:0" = add + matmul_1; add = matmul_1 = None
return (add_1,)
Postgrad graph: there is no duplicated AG in the forward graph because DCE removes it, but there are duplicated RS ops in the backward graph.
AOT graph
TRACED GRAPH
===== Forward graph 0 =====
/usr/local/lib/python3.11/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[5, 128][128, 1]cuda:0", primals_2: "f32[32, 128][128, 1]cuda:0", primals_3: "f32[5][1]cuda:0"):
# File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:250 in replicate_compute, code: output = x.redistribute(
all_gather_into_tensor: "f32[10, 128][128, 1]cuda:0" = torch.ops._c10d_functional.all_gather_into_tensor.default(primals_1, 2, '0'); primals_1 = None
wait_tensor: "f32[10, 128][128, 1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
permute: "f32[128, 10][1, 128]cuda:0" = torch.ops.aten.permute.default(wait_tensor, [1, 0]); wait_tensor = None
mm: "f32[32, 10][10, 1]cuda:0" = torch.ops.aten.mm.default(primals_2, permute); permute = None
# File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:250 in replicate_compute, code: output = x.redistribute(
all_gather_into_tensor_1: "f32[10][1]cuda:0" = torch.ops._c10d_functional.all_gather_into_tensor.default(primals_3, 2, '0'); primals_3 = None
wait_tensor_1: "f32[10][1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
add: "f32[32, 10][10, 1]cuda:0" = torch.ops.aten.add.Tensor(mm, wait_tensor_1); wait_tensor_1 = None
add_1: "f32[32, 10][10, 1]cuda:0" = torch.ops.aten.add.Tensor(add, mm); add = mm = None
return (add_1, primals_2)
TRACED GRAPH
===== Backward graph 0 =====
<eval_with_key>.2 class GraphModule(torch.nn.Module):
def forward(self, primals_2: "f32[32, 128][128, 1]cuda:0", tangents_1: "f32[32, 10][10, 1]cuda:0"):
# File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
permute_2: "f32[10, 32][1, 10]cuda:0" = torch.ops.aten.permute.default(tangents_1, [1, 0])
constant_pad_nd_default: "f32[12, 32][32, 1]cuda:0" = torch.ops.aten.constant_pad_nd.default(permute_2, [0, 0, 0, 2])
mm_default: "f32[12, 128][128, 1]cuda:0" = torch.ops.aten.mm.default(constant_pad_nd_default, primals_2); constant_pad_nd_default = None
slice_tensor: "f32[10, 128][128, 1]cuda:0" = torch.ops.aten.slice.Tensor(mm_default, 0, 0, -2); mm_default = None
# File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:250 in replicate_compute, code: output = x.redistribute(
reduce_scatter_tensor: "f32[5, 128][128, 1]cuda:0" = torch.ops._c10d_functional.reduce_scatter_tensor.default(slice_tensor, 'avg', 2, '0'); slice_tensor = None
wait_tensor_3: "f32[5, 128][128, 1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor); reduce_scatter_tensor = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
sum_1: "f32[1, 10][10, 1]cuda:0" = torch.ops.aten.sum.dim_IntList(tangents_1, [0], True); tangents_1 = None
view_3: "f32[10][1]cuda:0" = torch.ops.aten.view.default(sum_1, [10]); sum_1 = None
# File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:250 in replicate_compute, code: output = x.redistribute(
reduce_scatter_tensor_1: "f32[5][1]cuda:0" = torch.ops._c10d_functional.reduce_scatter_tensor.default(view_3, 'avg', 2, '0'); view_3 = None
wait_tensor_4: "f32[5][1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1); reduce_scatter_tensor_1 = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
mm_3: "f32[10, 128][128, 1]cuda:0" = torch.ops.aten.mm.default(permute_2, primals_2); permute_2 = primals_2 = None
# File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:250 in replicate_compute, code: output = x.redistribute(
reduce_scatter_tensor_2: "f32[5, 128][128, 1]cuda:0" = torch.ops._c10d_functional.reduce_scatter_tensor.default(mm_3, 'avg', 2, '0'); mm_3 = None
wait_tensor_5: "f32[5, 128][128, 1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_2); reduce_scatter_tensor_2 = None
# File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:250 in replicate_compute, code: output = x.redistribute(
add_3: "f32[5, 128][128, 1]cuda:0" = torch.ops.aten.add.Tensor(wait_tensor_3, wait_tensor_5); wait_tensor_3 = wait_tensor_5 = None
return (add_3, None, wait_tensor_4)
Proposals
To fix this, we suggest adding a cache table in ReplicateComputation to track and reuse unsharded weights. I'm slightly unsure about composability edge cases, but if this sounds like the right approach, we are happy to open a PR.
Note: we probably need to distinguish replicate_compute calls issued from different layers and only reuse cached results within the same layer, because in the reshard_after_forward=True case an AG has to be issued per layer.
Proposed pseudocode for ReplicateComputation.forward:
def forward(self, x: DTensor) -> torch.Tensor:
    """Return the replicated form of ``x``, reusing a cached result if present.

    Proposal sketch: the first call for a given ``x`` runs
    ``self.replicate_compute(x)`` and stores the result in
    ``self._fwd_cache``; later calls with the same object return the stored
    result instead of recomputing it.

    Args:
        x: the parameter passed in by the parametrization.

    Returns:
        ``x`` unchanged when ``_active_parametrization`` is falsy, otherwise
        the (possibly cached) output of ``self.replicate_compute(x)``.
    """
    global _active_parametrization
    # When the module-level flag is off, the parametrization is bypassed and
    # the input is returned untouched. Per the original note, this is meant
    # for model inspection / debugging / initialization outside of forward,
    # e.g.:
    #     with disable_active_parametrization():
    #         model.init_weights()
    if not _active_parametrization:
        return x

    # The cache is keyed on object identity of the incoming parameter.
    # NOTE(review): nothing in this sketch ever removes entries from
    # self._fwd_cache, and id() values are only unique among live objects --
    # an invalidation point (e.g. per step / per layer) still needs to be
    # defined. TODO confirm.
    param_id = id(x)
    if param_id in self._fwd_cache:
        return self._fwd_cache[param_id]

    output = self.replicate_compute(x)
    self._fwd_cache[param_id] = output
    return output
New Dynamo graph: no duplicated torch__dynamo_variables_tensor_prim_redistribute.
Dynamo graph
class GraphModule(torch.nn.Module):
def forward(self, L_self_modules_net_parameters_weight_: "f32[10, 128][128, 1]cuda:0", L_x_: "f32[32, 128][128, 1]cuda:0", L_self_modules_net_parameters_bias_: "f32[10][1]cuda:0"):
l_self_modules_net_parameters_weight_ = L_self_modules_net_parameters_weight_
l_x_ = L_x_
l_self_modules_net_parameters_bias_ = L_self_modules_net_parameters_bias_
# File: /mlx_devbox/users/yanbo.liang/playground/debug/simple_fsdp.py:251 in replicate_compute, code: output = x.redistribute(
output: "f32[10, 128][128, 1]cuda:0" = torch__dynamo_variables_tensor_prim_redistribute(l_self_modules_net_parameters_weight_); l_self_modules_net_parameters_weight_ = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/simple_fsdp.py:258 in replicate_compute, code: output = output.to_local(grad_placements=self.grad_placements)
output_1: "f32[10, 128][128, 1]cuda:0" = torch__dynamo_variables_tensor_prim_to_local(output); output = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
getattr_1: "f32[128, 10][1, 128]cuda:0" = output_1.T
matmul: "f32[32, 10][10, 1]cuda:0" = torch.matmul(l_x_, getattr_1); getattr_1 = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/simple_fsdp.py:251 in replicate_compute, code: output = x.redistribute(
output_2: "f32[10][1]cuda:0" = torch__dynamo_variables_tensor_prim_redistribute_1(l_self_modules_net_parameters_bias_); l_self_modules_net_parameters_bias_ = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/simple_fsdp.py:258 in replicate_compute, code: output = output.to_local(grad_placements=self.grad_placements)
output_3: "f32[10][1]cuda:0" = torch__dynamo_variables_tensor_prim_to_local_1(output_2); output_2 = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
add: "f32[32, 10][10, 1]cuda:0" = matmul + output_3; matmul = None
getattr_2: "f32[128, 10][1, 128]cuda:0" = output_1.T
matmul_1: "f32[32, 10][10, 1]cuda:0" = torch.matmul(l_x_, getattr_2); l_x_ = getattr_2 = None
add_1: "f32[32, 10][10, 1]cuda:0" = add + matmul_1; add = matmul_1 = None
return (add_1, output_1, output_3)
New postgrad graph: No duplicated AG/RS in fwd and bwd graphs.
AOT graph
TRACED GRAPH
===== Forward graph 0 =====
/usr/local/lib/python3.11/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[5, 128][128, 1]cuda:0", primals_2: "f32[32, 128][128, 1]cuda:0", primals_3: "f32[5][1]cuda:0"):
# File: /mlx_devbox/users/yanbo.liang/playground/debug/simple_fsdp.py:251 in replicate_compute, code: output = x.redistribute(
all_gather_into_tensor: "f32[10, 128][128, 1]cuda:0" = torch.ops._c10d_functional.all_gather_into_tensor.default(primals_1, 2, '0'); primals_1 = None
wait_tensor: "f32[10, 128][128, 1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
permute: "f32[128, 10][1, 128]cuda:0" = torch.ops.aten.permute.default(wait_tensor, [1, 0])
mm: "f32[32, 10][10, 1]cuda:0" = torch.ops.aten.mm.default(primals_2, permute); permute = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/simple_fsdp.py:251 in replicate_compute, code: output = x.redistribute(
all_gather_into_tensor_1: "f32[10][1]cuda:0" = torch.ops._c10d_functional.all_gather_into_tensor.default(primals_3, 2, '0'); primals_3 = None
wait_tensor_1: "f32[10][1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
add: "f32[32, 10][10, 1]cuda:0" = torch.ops.aten.add.Tensor(mm, wait_tensor_1); mm = None
permute_1: "f32[128, 10][1, 128]cuda:0" = torch.ops.aten.permute.default(wait_tensor, [1, 0])
mm_1: "f32[32, 10][10, 1]cuda:0" = torch.ops.aten.mm.default(primals_2, permute_1); permute_1 = None
add_1: "f32[32, 10][10, 1]cuda:0" = torch.ops.aten.add.Tensor(add, mm_1); add = mm_1 = None
return (add_1, wait_tensor, wait_tensor_1, primals_2)
TRACED GRAPH
===== Backward graph 0 =====
<eval_with_key>.2 class GraphModule(torch.nn.Module):
def forward(self, primals_2: "f32[32, 128][128, 1]cuda:0", tangents_1: "f32[32, 10][10, 1]cuda:0", tangents_2: "f32[10, 128][128, 1]cuda:0", tangents_3: "f32[10][1]cuda:0"):
# File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
permute_2: "f32[10, 32][1, 10]cuda:0" = torch.ops.aten.permute.default(tangents_1, [1, 0])
constant_pad_nd_default_1: "f32[12, 32][32, 1]cuda:0" = torch.ops.aten.constant_pad_nd.default(permute_2, [0, 0, 0, 2]); permute_2 = None
mm_default_1: "f32[12, 128][128, 1]cuda:0" = torch.ops.aten.mm.default(constant_pad_nd_default_1, primals_2); constant_pad_nd_default_1 = primals_2 = None
slice_tensor_1: "f32[10, 128][128, 1]cuda:0" = torch.ops.aten.slice.Tensor(mm_default_1, 0, 0, -2); mm_default_1 = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
add_2: "f32[10, 128][128, 1]cuda:0" = torch.ops.aten.add.Tensor(tangents_2, slice_tensor_1); tangents_2 = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
sum_1: "f32[1, 10][10, 1]cuda:0" = torch.ops.aten.sum.dim_IntList(tangents_1, [0], True); tangents_1 = None
view_2: "f32[10][1]cuda:0" = torch.ops.aten.view.default(sum_1, [10]); sum_1 = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
add_3: "f32[10][1]cuda:0" = torch.ops.aten.add.Tensor(tangents_3, view_2); tangents_3 = view_2 = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/simple_fsdp.py:251 in replicate_compute, code: output = x.redistribute(
reduce_scatter_tensor: "f32[5][1]cuda:0" = torch.ops._c10d_functional.reduce_scatter_tensor.default(add_3, 'avg', 2, '0'); add_3 = None
wait_tensor_2: "f32[5][1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor); reduce_scatter_tensor = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
add_4: "f32[10, 128][128, 1]cuda:0" = torch.ops.aten.add.Tensor(add_2, slice_tensor_1); add_2 = slice_tensor_1 = None
# File: /mlx_devbox/users/yanbo.liang/playground/debug/simple_fsdp.py:251 in replicate_compute, code: output = x.redistribute(
reduce_scatter_tensor_1: "f32[5, 128][128, 1]cuda:0" = torch.ops._c10d_functional.reduce_scatter_tensor.default(add_4, 'avg', 2, '0'); add_4 = None
wait_tensor_3: "f32[5, 128][128, 1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1); reduce_scatter_tensor_1 = None
return (wait_tensor_3, None, wait_tensor_2)
Another idea is to add DCE logic as an FX pass that removes the duplicated collectives. This is more flexible: in some scenarios, such as weight tying between the embedding layer and the lm_head layer, duplicated AGs are necessary, so we have to distinguish which cases should allow a duplicated AG and which should not.
This issue was discovered by @yyp0. cc @ruisizhang123 @tianyu-l
Versions
torch 2.9