
Duplicated AG/RS issued from Simple FSDP if there are multiple references to the same parameter #2133

@yanboliang


Bug description

Repro


import torch
import torch.nn as nn
import time
import os
import torch.distributed as dist
from simple_fsdp import data_parallel


class SimpleMLP(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.net = nn.Linear(input_dim, output_dim)

    def forward(self, x):
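        # Note: self.net.weight is used twice on purpose so that SimpleFSDP's
        # parametrization (replicate_compute) runs twice for the same parameter.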
        return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)

input_dim, output_dim = 128, 10


def main():
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    world_size = int(os.environ["WORLD_SIZE"])
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    dist.init_process_group("nccl", rank=rank, world_size=world_size, device_id=device)

    dp_mesh = dist.device_mesh.init_device_mesh(
        device.type,
        mesh_shape=(2,),
        mesh_dim_names=("dp",),
    )

    module = SimpleMLP(input_dim, output_dim).to(device)
    module = data_parallel(module, dp_mesh, mode="fully_shard")
    module = torch.compile(module, mode="reduce-overhead")

    x = torch.randn(32, input_dim, device=device)
    output = module(x)
    print(f"Output shape: {output.shape}")


if __name__ == "__main__":
    main()

Problem


Currently, reusing a parameter in a module's forward causes Simple FSDP to trigger multiple All-Gather operations for the same parameter (self.net.weight in the example above), duplicating AG/RS nodes in the FX graph and hurting performance.

Dynamo graph: duplicated torch__dynamo_variables_tensor_prim_redistribute calls

Dynamo graph
class GraphModule(torch.nn.Module):
    def forward(self, L_self_modules_net_parameters_weight_: "f32[10, 128][128, 1]cuda:0", L_x_: "f32[32, 128][128, 1]cuda:0", L_self_modules_net_parameters_bias_: "f32[10][1]cuda:0"):
        l_self_modules_net_parameters_weight_ = L_self_modules_net_parameters_weight_
        l_x_ = L_x_
        l_self_modules_net_parameters_bias_ = L_self_modules_net_parameters_bias_
        
        # File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:250 in replicate_compute, code: output = x.redistribute(
        output: "f32[10, 128][128, 1]cuda:0" = torch__dynamo_variables_tensor_prim_redistribute(l_self_modules_net_parameters_weight_)
        
        # File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:257 in replicate_compute, code: output = output.to_local(grad_placements=self.grad_placements)
        output_1: "f32[10, 128][128, 1]cuda:0" = torch__dynamo_variables_tensor_prim_to_local(output);  output = None
        
        # File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
        getattr_1: "f32[128, 10][1, 128]cuda:0" = output_1.T;  output_1 = None
        matmul: "f32[32, 10][10, 1]cuda:0" = torch.matmul(l_x_, getattr_1);  getattr_1 = None
        
        # File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:250 in replicate_compute, code: output = x.redistribute(
        output_2: "f32[10][1]cuda:0" = torch__dynamo_variables_tensor_prim_redistribute_1(l_self_modules_net_parameters_bias_);  l_self_modules_net_parameters_bias_ = None
        
        # File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:257 in replicate_compute, code: output = output.to_local(grad_placements=self.grad_placements)
        output_3: "f32[10][1]cuda:0" = torch__dynamo_variables_tensor_prim_to_local_1(output_2);  output_2 = None
        
        # File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
        add: "f32[32, 10][10, 1]cuda:0" = matmul + output_3;  matmul = output_3 = None
        
        # File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:250 in replicate_compute, code: output = x.redistribute(
        output_4: "f32[10, 128][128, 1]cuda:0" = torch__dynamo_variables_tensor_prim_redistribute_2(l_self_modules_net_parameters_weight_);  l_self_modules_net_parameters_weight_ = None
        
        # File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:257 in replicate_compute, code: output = output.to_local(grad_placements=self.grad_placements)
        output_5: "f32[10, 128][128, 1]cuda:0" = torch__dynamo_variables_tensor_prim_to_local_2(output_4);  output_4 = None
        
        # File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
        getattr_2: "f32[128, 10][1, 128]cuda:0" = output_5.T;  output_5 = None
        matmul_1: "f32[32, 10][10, 1]cuda:0" = torch.matmul(l_x_, getattr_2);  l_x_ = getattr_2 = None
        add_1: "f32[32, 10][10, 1]cuda:0" = add + matmul_1;  add = matmul_1 = None
        return (add_1,)

Postgrad graph: no duplicated AG in the forward graph because DCE removes it, but there are duplicated RS ops in the backward graph.

AOT graph
 TRACED GRAPH
  ===== Forward graph 0 =====
  /usr/local/lib/python3.11/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
     def forward(self, primals_1: "f32[5, 128][128, 1]cuda:0", primals_2: "f32[32, 128][128, 1]cuda:0", primals_3: "f32[5][1]cuda:0"):
          # File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:250 in replicate_compute, code: output = x.redistribute(
         all_gather_into_tensor: "f32[10, 128][128, 1]cuda:0" = torch.ops._c10d_functional.all_gather_into_tensor.default(primals_1, 2, '0');  primals_1 = None
         wait_tensor: "f32[10, 128][128, 1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor);  all_gather_into_tensor = None
         
          # File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
         permute: "f32[128, 10][1, 128]cuda:0" = torch.ops.aten.permute.default(wait_tensor, [1, 0]);  wait_tensor = None
         mm: "f32[32, 10][10, 1]cuda:0" = torch.ops.aten.mm.default(primals_2, permute);  permute = None
         
          # File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:250 in replicate_compute, code: output = x.redistribute(
         all_gather_into_tensor_1: "f32[10][1]cuda:0" = torch.ops._c10d_functional.all_gather_into_tensor.default(primals_3, 2, '0');  primals_3 = None
         wait_tensor_1: "f32[10][1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1);  all_gather_into_tensor_1 = None
         
          # File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
         add: "f32[32, 10][10, 1]cuda:0" = torch.ops.aten.add.Tensor(mm, wait_tensor_1);  wait_tensor_1 = None
         add_1: "f32[32, 10][10, 1]cuda:0" = torch.ops.aten.add.Tensor(add, mm);  add = mm = None
         return (add_1, primals_2)
         
 
 TRACED GRAPH
  ===== Backward graph 0 =====
  <eval_with_key>.2 class GraphModule(torch.nn.Module):
     def forward(self, primals_2: "f32[32, 128][128, 1]cuda:0", tangents_1: "f32[32, 10][10, 1]cuda:0"):
          # File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
         permute_2: "f32[10, 32][1, 10]cuda:0" = torch.ops.aten.permute.default(tangents_1, [1, 0])
         constant_pad_nd_default: "f32[12, 32][32, 1]cuda:0" = torch.ops.aten.constant_pad_nd.default(permute_2, [0, 0, 0, 2])
         mm_default: "f32[12, 128][128, 1]cuda:0" = torch.ops.aten.mm.default(constant_pad_nd_default, primals_2);  constant_pad_nd_default = None
         slice_tensor: "f32[10, 128][128, 1]cuda:0" = torch.ops.aten.slice.Tensor(mm_default, 0, 0, -2);  mm_default = None
         
          # File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:250 in replicate_compute, code: output = x.redistribute(
         reduce_scatter_tensor: "f32[5, 128][128, 1]cuda:0" = torch.ops._c10d_functional.reduce_scatter_tensor.default(slice_tensor, 'avg', 2, '0');  slice_tensor = None
         wait_tensor_3: "f32[5, 128][128, 1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor);  reduce_scatter_tensor = None
         
          # File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
         sum_1: "f32[1, 10][10, 1]cuda:0" = torch.ops.aten.sum.dim_IntList(tangents_1, [0], True);  tangents_1 = None
         view_3: "f32[10][1]cuda:0" = torch.ops.aten.view.default(sum_1, [10]);  sum_1 = None
         
          # File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:250 in replicate_compute, code: output = x.redistribute(
         reduce_scatter_tensor_1: "f32[5][1]cuda:0" = torch.ops._c10d_functional.reduce_scatter_tensor.default(view_3, 'avg', 2, '0');  view_3 = None
         wait_tensor_4: "f32[5][1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1);  reduce_scatter_tensor_1 = None
         
          # File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
         mm_3: "f32[10, 128][128, 1]cuda:0" = torch.ops.aten.mm.default(permute_2, primals_2);  permute_2 = primals_2 = None
         
          # File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:250 in replicate_compute, code: output = x.redistribute(
         reduce_scatter_tensor_2: "f32[5, 128][128, 1]cuda:0" = torch.ops._c10d_functional.reduce_scatter_tensor.default(mm_3, 'avg', 2, '0');  mm_3 = None
         wait_tensor_5: "f32[5, 128][128, 1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_2);  reduce_scatter_tensor_2 = None
         
          # File: /mlx_devbox/users/yanbo.liang/playground/aether/aether/distributed/simple_fsdp.py:250 in replicate_compute, code: output = x.redistribute(
         add_3: "f32[5, 128][128, 1]cuda:0" = torch.ops.aten.add.Tensor(wait_tensor_3, wait_tensor_5);  wait_tensor_3 = wait_tensor_5 = None
         return (add_3, None, wait_tensor_4)

Proposals


To fix this, we suggest adding a cache table in ReplicateComputation to track and reuse unsharded weights. I'm slightly unsure about composability edge cases, but if this sounds like the right approach, we are happy to open a PR.
Note: we probably need to distinguish the all-gathers that replicate_compute issues for different layers and only cache reuses within the same layer, because we still have to issue an AG per layer in the reshard_after_forward=True case.
Proposed pseudocode for ReplicateComputation.forward:

    def forward(self, x: DTensor) -> torch.Tensor:
        global _active_parametrization
        # This should never be set to true during forward, only outside for model
        # inspection / debugging / initialization
        # model initialization can be done now through
        # with disable_active_parametrization():
        #     model.init_weights()
        if not _active_parametrization:
            return x
        
        param_id = id(x)
        if param_id in self._fwd_cache:
            return self._fwd_cache[param_id]

        output = self.replicate_compute(x)

        self._fwd_cache[param_id] = output

        return output
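
For illustration only, here is a minimal runnable sketch of the caching behavior outside of SimpleFSDP/compile. FakeReplicate, _unshard, and _fwd_cache are made-up names for this example, and _unshard merely stands in for the redistribute()/to_local() call that would issue the real all-gather; the last few lines also show the per-forward cache clearing that the note above is about.

import torch
import torch.nn as nn


class FakeReplicate(nn.Module):
    """Toy stand-in for ReplicateComputation that caches the unsharded view per forward."""

    def __init__(self):
        super().__init__()
        self._fwd_cache: dict[int, torch.Tensor] = {}
        self.unshard_calls = 0  # counts how many "all-gathers" were issued

    def _unshard(self, x: torch.Tensor) -> torch.Tensor:
        # Stand-in for redistribute(Replicate) + to_local(); a real implementation
        # issues an all_gather_into_tensor here.
        self.unshard_calls += 1
        return x.clone()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Key by parameter identity so repeated uses of the same parameter
        # within one forward reuse a single unsharded tensor.
        key = id(x)
        cached = self._fwd_cache.get(key)
        if cached is not None:
            return cached
        out = self._unshard(x)
        self._fwd_cache[key] = out
        return out


replicate = FakeReplicate()
w = torch.randn(10, 128)

# Two uses of the same parameter within one forward -> a single "unshard".
a = replicate(w)
b = replicate(w)
assert a is b and replicate.unshard_calls == 1

# Clearing the cache between forwards restores one unshard per forward,
# which is what the reshard_after_forward=True / per-layer note above is about.
replicate._fwd_cache.clear()
replicate(w)
assert replicate.unshard_calls == 2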

New Dynamo graph: no duplicated torch__dynamo_variables_tensor_prim_redistribute calls.

Dynamo graph
class GraphModule(torch.nn.Module):
    def forward(self, L_self_modules_net_parameters_weight_: "f32[10, 128][128, 1]cuda:0", L_x_: "f32[32, 128][128, 1]cuda:0", L_self_modules_net_parameters_bias_: "f32[10][1]cuda:0"):
        l_self_modules_net_parameters_weight_ = L_self_modules_net_parameters_weight_
        l_x_ = L_x_
        l_self_modules_net_parameters_bias_ = L_self_modules_net_parameters_bias_
        
        # File: /mlx_devbox/users/yanbo.liang/playground/debug/simple_fsdp.py:251 in replicate_compute, code: output = x.redistribute(
        output: "f32[10, 128][128, 1]cuda:0" = torch__dynamo_variables_tensor_prim_redistribute(l_self_modules_net_parameters_weight_);  l_self_modules_net_parameters_weight_ = None
        
        # File: /mlx_devbox/users/yanbo.liang/playground/debug/simple_fsdp.py:258 in replicate_compute, code: output = output.to_local(grad_placements=self.grad_placements)
        output_1: "f32[10, 128][128, 1]cuda:0" = torch__dynamo_variables_tensor_prim_to_local(output);  output = None
        
        # File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
        getattr_1: "f32[128, 10][1, 128]cuda:0" = output_1.T
        matmul: "f32[32, 10][10, 1]cuda:0" = torch.matmul(l_x_, getattr_1);  getattr_1 = None
        
        # File: /mlx_devbox/users/yanbo.liang/playground/debug/simple_fsdp.py:251 in replicate_compute, code: output = x.redistribute(
        output_2: "f32[10][1]cuda:0" = torch__dynamo_variables_tensor_prim_redistribute_1(l_self_modules_net_parameters_bias_);  l_self_modules_net_parameters_bias_ = None
        
        # File: /mlx_devbox/users/yanbo.liang/playground/debug/simple_fsdp.py:258 in replicate_compute, code: output = output.to_local(grad_placements=self.grad_placements)
        output_3: "f32[10][1]cuda:0" = torch__dynamo_variables_tensor_prim_to_local_1(output_2);  output_2 = None
        
        # File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
        add: "f32[32, 10][10, 1]cuda:0" = matmul + output_3;  matmul = None
        getattr_2: "f32[128, 10][1, 128]cuda:0" = output_1.T
        matmul_1: "f32[32, 10][10, 1]cuda:0" = torch.matmul(l_x_, getattr_2);  l_x_ = getattr_2 = None
        add_1: "f32[32, 10][10, 1]cuda:0" = add + matmul_1;  add = matmul_1 = None
        return (add_1, output_1, output_3)

New postgrad graph: no duplicated AG/RS in the forward and backward graphs.

AOT graph
 TRACED GRAPH
  ===== Forward graph 0 =====
  /usr/local/lib/python3.11/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
     def forward(self, primals_1: "f32[5, 128][128, 1]cuda:0", primals_2: "f32[32, 128][128, 1]cuda:0", primals_3: "f32[5][1]cuda:0"):
          # File: /mlx_devbox/users/yanbo.liang/playground/debug/simple_fsdp.py:251 in replicate_compute, code: output = x.redistribute(
         all_gather_into_tensor: "f32[10, 128][128, 1]cuda:0" = torch.ops._c10d_functional.all_gather_into_tensor.default(primals_1, 2, '0');  primals_1 = None
         wait_tensor: "f32[10, 128][128, 1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor);  all_gather_into_tensor = None
         
          # File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
         permute: "f32[128, 10][1, 128]cuda:0" = torch.ops.aten.permute.default(wait_tensor, [1, 0])
         mm: "f32[32, 10][10, 1]cuda:0" = torch.ops.aten.mm.default(primals_2, permute);  permute = None
         
          # File: /mlx_devbox/users/yanbo.liang/playground/debug/simple_fsdp.py:251 in replicate_compute, code: output = x.redistribute(
         all_gather_into_tensor_1: "f32[10][1]cuda:0" = torch.ops._c10d_functional.all_gather_into_tensor.default(primals_3, 2, '0');  primals_3 = None
         wait_tensor_1: "f32[10][1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1);  all_gather_into_tensor_1 = None
         
          # File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
         add: "f32[32, 10][10, 1]cuda:0" = torch.ops.aten.add.Tensor(mm, wait_tensor_1);  mm = None
         permute_1: "f32[128, 10][1, 128]cuda:0" = torch.ops.aten.permute.default(wait_tensor, [1, 0])
         mm_1: "f32[32, 10][10, 1]cuda:0" = torch.ops.aten.mm.default(primals_2, permute_1);  permute_1 = None
         add_1: "f32[32, 10][10, 1]cuda:0" = torch.ops.aten.add.Tensor(add, mm_1);  add = mm_1 = None
         return (add_1, wait_tensor, wait_tensor_1, primals_2)
         
 
 TRACED GRAPH
  ===== Backward graph 0 =====
  <eval_with_key>.2 class GraphModule(torch.nn.Module):
     def forward(self, primals_2: "f32[32, 128][128, 1]cuda:0", tangents_1: "f32[32, 10][10, 1]cuda:0", tangents_2: "f32[10, 128][128, 1]cuda:0", tangents_3: "f32[10][1]cuda:0"):
          # File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
         permute_2: "f32[10, 32][1, 10]cuda:0" = torch.ops.aten.permute.default(tangents_1, [1, 0])
         constant_pad_nd_default_1: "f32[12, 32][32, 1]cuda:0" = torch.ops.aten.constant_pad_nd.default(permute_2, [0, 0, 0, 2]);  permute_2 = None
         mm_default_1: "f32[12, 128][128, 1]cuda:0" = torch.ops.aten.mm.default(constant_pad_nd_default_1, primals_2);  constant_pad_nd_default_1 = primals_2 = None
         slice_tensor_1: "f32[10, 128][128, 1]cuda:0" = torch.ops.aten.slice.Tensor(mm_default_1, 0, 0, -2);  mm_default_1 = None
         
          # File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
         add_2: "f32[10, 128][128, 1]cuda:0" = torch.ops.aten.add.Tensor(tangents_2, slice_tensor_1);  tangents_2 = None
         
          # File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
         sum_1: "f32[1, 10][10, 1]cuda:0" = torch.ops.aten.sum.dim_IntList(tangents_1, [0], True);  tangents_1 = None
         view_2: "f32[10][1]cuda:0" = torch.ops.aten.view.default(sum_1, [10]);  sum_1 = None
         
          # File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
         add_3: "f32[10][1]cuda:0" = torch.ops.aten.add.Tensor(tangents_3, view_2);  tangents_3 = view_2 = None
         
          # File: /mlx_devbox/users/yanbo.liang/playground/debug/simple_fsdp.py:251 in replicate_compute, code: output = x.redistribute(
         reduce_scatter_tensor: "f32[5][1]cuda:0" = torch.ops._c10d_functional.reduce_scatter_tensor.default(add_3, 'avg', 2, '0');  add_3 = None
         wait_tensor_2: "f32[5][1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor);  reduce_scatter_tensor = None
         
          # File: /mlx_devbox/users/yanbo.liang/playground/debug/debug1.py:16 in forward, code: return torch.matmul(x, self.net.weight.T) + self.net.bias + torch.matmul(x, self.net.weight.T)
         add_4: "f32[10, 128][128, 1]cuda:0" = torch.ops.aten.add.Tensor(add_2, slice_tensor_1);  add_2 = slice_tensor_1 = None
         
          # File: /mlx_devbox/users/yanbo.liang/playground/debug/simple_fsdp.py:251 in replicate_compute, code: output = x.redistribute(
         reduce_scatter_tensor_1: "f32[5, 128][128, 1]cuda:0" = torch.ops._c10d_functional.reduce_scatter_tensor.default(add_4, 'avg', 2, '0');  add_4 = None
         wait_tensor_3: "f32[5, 128][128, 1]cuda:0" = torch.ops._c10d_functional.wait_tensor.default(reduce_scatter_tensor_1);  reduce_scatter_tensor_1 = None
         return (wait_tensor_3, None, wait_tensor_2)

Another idea is to add dedup (DCE-like) logic as an FX pass that removes the duplicated collectives. This is more flexible, but in some scenarios, such as weight tying between the embedding layer and the lm_head layer, the duplicated AGs are necessary, so we would have to distinguish which cases should allow duplicated AG and which should not.
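
As a rough illustration of that direction (not an actual SimpleFSDP or Inductor pass), the sketch below is a generic CSE-style FX pass that folds call_function/call_method nodes with identical targets and arguments. A real pass would restrict this to the functional collectives (all_gather_into_tensor / reduce_scatter_tensor) and would need the allow-list discussed above for legitimate duplicates such as tied embedding/lm_head weights; it is demonstrated on plain matmuls here so it runs without a distributed setup.

import torch
import torch.fx as fx


def dedup_identical_calls(gm: fx.GraphModule) -> fx.GraphModule:
    """Fold nodes that call the same target with identical (already deduplicated) args."""
    seen: dict[tuple, fx.Node] = {}
    for node in list(gm.graph.nodes):
        if node.op not in ("call_function", "call_method"):
            continue
        key = (
            node.op,
            node.target,
            node.args,
            tuple(sorted(node.kwargs.items(), key=lambda kv: kv[0])),
        )
        try:
            hash(key)
        except TypeError:
            continue  # args contain unhashable values; leave the node alone
        if key in seen:
            # Re-point users to the first identical call and drop the duplicate.
            node.replace_all_uses_with(seen[key])
            gm.graph.erase_node(node)
        else:
            seen[key] = node
    gm.recompile()
    return gm


def f(x, w):
    # Same shape of duplication as the repro: w.t() and the matmul appear twice.
    return x @ w.t() + x @ w.t()


gm = fx.symbolic_trace(f)
n_before = sum(n.op in ("call_function", "call_method") for n in gm.graph.nodes)
dedup_identical_calls(gm)
n_after = sum(n.op in ("call_function", "call_method") for n in gm.graph.nodes)

x, w = torch.randn(32, 128), torch.randn(10, 128)
assert torch.allclose(gm(x, w), f(x, w))
print(f"call nodes: {n_before} -> {n_after}")  # the duplicated t()/matmul pair is folded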

This issue was discovered by @yyp0. cc @ruisizhang123 @tianyu-l

Versions

torch 2.9
