Compile-based Pipeline Parallelism #199

@fmassa

We are going to quickly prototype a proof of concept for the pipeline proposal for DSv3 with DualPipeV. There are four sub-tasks:

1 - Set up basic pipeline infra that: @sanketpurandare

  • (a) splits the model into logical stages
  • (b) traces the joint fwd+bwd for each stage
  • (c) partitions the joint fwd+bwd
  • (d) stitches the fwd and bwd of two different stages to produce a naive multiplexed graph
  • (e) runs with the schedule described by PipelineIR and the current torch.distributed.pipelining machinery
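
As a rough illustration of step (a), here is a framework-free sketch of splitting a model into logical stages. The helper names `split_into_stages` and `v_shape_placement` are hypothetical, and the placement rule (rank r holds stages r and 2N-1-r) is an assumption about how a V-shaped schedule such as DualPipeV lays out stages, not something this issue specifies. A real implementation would split `nn.Module`s rather than strings.

```python
def split_into_stages(layers, num_stages):
    """Split `layers` into `num_stages` contiguous chunks of near-equal size."""
    base, extra = divmod(len(layers), num_stages)
    stages, start = [], 0
    for i in range(num_stages):
        # The first `extra` stages take one additional layer each.
        size = base + (1 if i < extra else 0)
        stages.append(layers[start:start + size])
        start += size
    return stages


def v_shape_placement(num_ranks):
    """Assumed V-shaped placement: rank r holds stages r and 2*num_ranks-1-r."""
    return {r: (r, 2 * num_ranks - 1 - r) for r in range(num_ranks)}


# Stand-ins for real modules: 8 layers, 2 ranks, 2 stages per rank.
layers = [f"layer{i}" for i in range(8)]
stages = split_into_stages(layers, num_stages=4)
placement = v_shape_placement(num_ranks=2)

for rank, stage_ids in placement.items():
    print(rank, [stages[s] for s in stage_ids])
```

With this layout each rank owns one early and one late stage, which is what makes it possible to pair the fwd of one stage with the bwd of another on the same rank, as in step (d).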

2 - Partition a graph with FSDP collectives into two sub-graphs: @IvanKobzarev

  • (a) one that contains the FSDP collectives to produce unsharded params as outputs and
  • (b) the other that contains the remaining ops that directly run on unsharded params as inputs.
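
The split described in (a) and (b) can be sketched on a toy graph. This is only an illustration: `Node`, `FSDP_OPS`, and `partition_fsdp` are hypothetical names, the op strings are placeholders rather than real ATen ops, and the sketch assumes the FSDP collectives depend only on sharded parameters (so it does not cover the reduce-scatter side in the backward).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    name: str
    op: str
    inputs: tuple = ()


# Hypothetical op names standing in for the FSDP all-gather and its wait.
FSDP_OPS = {"all_gather", "wait"}


def partition_fsdp(nodes):
    """Split a topologically ordered node list into (comm, compute, boundary).

    `comm` holds the FSDP ops plus everything they depend on; `boundary`
    lists the comm-graph values consumed by the compute graph, i.e. the
    unsharded params.
    """
    by_name = {n.name: n for n in nodes}
    comm_names = set()

    def mark(name):
        # Pull a node and all of its ancestors into the comm sub-graph.
        if name in comm_names:
            return
        comm_names.add(name)
        for inp in by_name[name].inputs:
            mark(inp)

    for n in nodes:
        if n.op in FSDP_OPS:
            mark(n.name)

    comm = [n for n in nodes if n.name in comm_names]
    compute = [n for n in nodes if n.name not in comm_names]
    boundary = []
    for n in compute:
        for inp in n.inputs:
            if inp in comm_names and inp not in boundary:
                boundary.append(inp)
    return comm, compute, boundary


graph = [
    Node("w_shard", "placeholder"),
    Node("x", "placeholder"),
    Node("ag", "all_gather", ("w_shard",)),
    Node("w", "wait", ("ag",)),
    Node("y", "matmul", ("x", "w")),
    Node("out", "relu", ("y",)),
]
comm, compute, boundary = partition_fsdp(graph)
print([n.name for n in comm], [n.name for n in compute], boundary)
```

Here the unsharded parameter `w` is the only value crossing the boundary: it is an output of the first sub-graph and an input of the second.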

3 - Overlapping comms and compute within the multiplexed fwd and bwd graph

  • (a) Let the user annotate regions of code that appear as tags in the fx graph (Sherlock's PR). We only consider two distinct regions for now (Region A: Compute, Region B: Comms) @xmfan
  • (b) Use the fx node tags to reorder the nodes in the multiplexed graph @sanketpurandare
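
One possible shape for the tag-driven reordering in (b) is a topological sort that, among the nodes whose inputs are ready, issues comm-tagged nodes first so they can overlap with the compute that follows. The sketch below uses a toy `Node` type and a hypothetical `reorder_by_tag` helper. It is not the actual pass, and real overlap also depends on the collectives being asynchronous.

```python
import heapq
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    name: str
    tag: str  # "compute" or "comm"
    inputs: tuple = ()


def reorder_by_tag(nodes):
    """Kahn's algorithm: among ready nodes, comm goes first, then original order."""
    by_name = {n.name: n for n in nodes}
    index = {n.name: i for i, n in enumerate(nodes)}
    users = {n.name: [] for n in nodes}
    pending = {}
    for n in nodes:
        pending[n.name] = len(n.inputs)
        for inp in n.inputs:
            users[inp].append(n.name)

    def key(name):
        # Lower tuples pop first: comm (0) before compute (1), ties by position.
        return (0 if by_name[name].tag == "comm" else 1, index[name], name)

    ready = [key(n.name) for n in nodes if pending[n.name] == 0]
    heapq.heapify(ready)
    order = []
    while ready:
        _, _, name = heapq.heappop(ready)
        order.append(name)
        for user in users[name]:
            pending[user] -= 1
            if pending[user] == 0:
                heapq.heappush(ready, key(user))
    return order


# Naive multiplexed graph: fwd compute of one stage, then comm + bwd of another.
multiplexed = [
    Node("f1", "compute"),
    Node("f2", "compute", ("f1",)),
    Node("c1", "comm"),
    Node("b1", "compute", ("c1",)),
]
order = reorder_by_tag(multiplexed)
print(order)
```

In this example `c1` moves to the front, so the independent compute `f1` and `f2` sits between the collective and its first consumer `b1`.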

4 - Partition the bwd graph into two sub-graphs: @bdhirsh

  • (a) bwd_dI: Backward graph that only computes the gradients with respect to inputs
  • (b) bwd_dW: Backward graph that computes the gradients with respect to weights
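
To make the dI/dW split concrete, consider a single linear layer y = x @ W. Its backward has two independent pieces: the input gradient, which the previous stage is waiting on, and the weight gradient, which nothing downstream depends on and which can therefore be scheduled later. The sketch below uses plain Python lists, and the function names only mirror the sub-graph names above for illustration.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transpose(m):
    return [list(row) for row in zip(*m)]


def bwd_dI(grad_y, w):
    """Gradient w.r.t. the input: dL/dx = dL/dy @ W^T."""
    return matmul(grad_y, transpose(w))


def bwd_dW(x, grad_y):
    """Gradient w.r.t. the weight: dL/dW = x^T @ dL/dy."""
    return matmul(transpose(x), grad_y)


x = [[1.0, 2.0]]                           # shape 1x2
w = [[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]     # shape 2x3
grad_y = [[1.0, 0.0, -1.0]]                # shape 1x3

grad_x = bwd_dI(grad_y, w)   # needed right away by the previous stage
grad_w = bwd_dW(x, grad_y)   # can be deferred until the optimizer step
print(grad_x, grad_w)
```

Note that `bwd_dW` needs the saved activation `x` and the incoming `grad_y`, so deferring it means keeping both alive until it runs.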

We will work on all four of these in parallel.
@bdhirsh is going to unblock any compile issues with AC/functionalization or dynamic shapes.
