Open
Description
We are going to quickly prototype a proof of concept of the pipeline proposal for DSv3 with DualPipeV. There are four sub-tasks:
1 - Set up basic pipeline infra that: @sanketpurandare
- (a) splits the model into logical stages
- (b) traces the joint fwd+bwd for each stage
- (c) partitions the joint fwd+bwd
- (d) stitches the fwd and bwd of two different stages to produce a naive multiplexed graph
- (e) runs with the schedule described by PipelineIR and the current torch.distributed.pipelining machinery
2 - Partition a graph with FSDP collectives into two sub-graphs: @IvanKobzarev
- (a) one that contains the FSDP collectives to produce unsharded params as outputs and
- (b) the other that contains the remaining ops that directly run on unsharded params as inputs.
3 - Overlap comms and compute within the multiplexed fwd and bwd graph:
- (a) Let the user annotate regions of code that appear as tags in the fx graph (Sherlock's PR). We only consider two distinct regions for now (Region A: Compute, Region B: Comms) @xmfan
- (b) Use the fx node tags to reorder the nodes in the multiplexed graph @sanketpurandare
4 - Partition the bwd graph into two sub-graphs: @bdhirsh
- (a) bwd_dI: Backward graph that only computes the gradients with respect to inputs
- (b) bwd_dW: Backward graph that computes the gradients with respect to weights
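To make the split in sub-task 4 concrete, here is a minimal sketch for a single linear layer `y = x @ W`, using plain Python lists in place of tensors. It is a toy model, not the planned graph partitioner: the function bodies are illustrative, and the comment that dW can be deferred is an assumption about how the schedule would use the split.

```python
# Toy model of sub-task 4: the backward of y = x @ W splits into two
# independent pieces. Plain Python lists stand in for tensors.

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    """Swap rows and columns."""
    return [list(col) for col in zip(*m)]

def bwd_dI(grad_out, weight):
    """Gradient w.r.t. the input: dX = dY @ W^T (what the previous stage needs)."""
    return matmul(grad_out, transpose(weight))

def bwd_dW(grad_out, saved_input):
    """Gradient w.r.t. the weight: dW = X^T @ dY (assumed to be deferrable)."""
    return matmul(transpose(saved_input), grad_out)

x = [[1.0, 2.0]]                          # input, shape (1, 2)
w = [[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]    # weight, shape (2, 3)
grad_y = [[1.0, 1.0, 1.0]]                # upstream gradient, shape (1, 3)

grad_x = bwd_dI(grad_y, w)   # depends only on grad_y and w
grad_w = bwd_dW(grad_y, x)   # depends only on grad_y and the saved input
```

Neither function reads the other's result, which is what allows the two sub-graphs to be scheduled separately.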
We will work on all four of these in parallel.
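As a rough illustration of how sub-tasks 1(d) and 3(b) fit together, the sketch below stitches the fwd nodes of one stage with the bwd nodes of another and then reorders them by tag. Everything here is hypothetical: `Node`, the tag values `"compute"` and `"comms"`, the node names, and the greedy policy are stand-ins for fx nodes and whatever reordering pass is actually written.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    name: str
    tag: str          # "compute" or "comms" (hypothetical tag values)
    deps: tuple = ()  # names of nodes that must be issued first

def stitch(fwd_nodes, bwd_nodes):
    """Naive multiplexing: fwd of one stage followed by bwd of another."""
    return list(fwd_nodes) + list(bwd_nodes)

def reorder_by_tag(nodes):
    """Greedy reorder: after issuing comms, prefer compute that does not wait on it."""
    issued, order, pending = set(), [], list(nodes)
    in_flight = None  # name of the most recently issued comms node, if any
    while pending:
        ready = [n for n in pending if all(d in issued for d in n.deps)]
        if not ready:
            raise ValueError("dependency cycle or missing node")
        if in_flight is None:
            # Nothing to hide yet: launch communication as early as possible.
            pick = next((n for n in ready if n.tag == "comms"), ready[0])
        else:
            # Hide the in-flight comms behind independent compute.
            pick = next(
                (n for n in ready if n.tag == "compute" and in_flight not in n.deps),
                ready[0],
            )
        in_flight = pick.name if pick.tag == "comms" else None
        pending.remove(pick)
        issued.add(pick.name)
        order.append(pick)
    return order

fwd_stage0 = [
    Node("f_attn", "compute"),
    Node("f_a2a", "comms", ("f_attn",)),
    Node("f_mlp", "compute", ("f_a2a",)),
]
bwd_stage1 = [
    Node("b_mlp", "compute"),
    Node("b_a2a", "comms", ("b_mlp",)),
    Node("b_attn", "compute", ("b_a2a",)),
]

stitched = stitch(fwd_stage0, bwd_stage1)
order = [n.name for n in reorder_by_tag(stitched)]
# order == ["f_attn", "f_a2a", "b_mlp", "b_a2a", "f_mlp", "b_attn"]
```

In the reordered list each comms node is followed by a compute node from the other stage, so the fwd communication can overlap with bwd compute and vice versa.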
@bdhirsh is going to unblock any compile issues with AC/functionalization or dynamic shapes.
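For sub-task 2, the partition can be pictured on a toy graph as follows. This is only a sketch under stated assumptions: `Node`, the op name `"all_gather"`, and the backward walk from each collective are hypothetical and do not reflect the real fx graph or FSDP internals.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    name: str
    op: str
    inputs: tuple = ()  # names of producer nodes

COMM_OPS = {"all_gather"}  # hypothetical op names for the FSDP collectives

def partition_fsdp(nodes):
    """Split a topologically ordered node list into (comm_graph, compute_graph, boundary)."""
    by_name = {n.name: n for n in nodes}
    stack = [n.name for n in nodes if n.op in COMM_OPS]
    # Unsharded params: outputs of the first graph, inputs of the second.
    boundary = list(stack)
    comm_names = set()
    # Walk backwards from each collective and claim everything it depends on.
    while stack:
        name = stack.pop()
        if name in comm_names:
            continue
        comm_names.add(name)
        stack.extend(by_name[name].inputs)
    comm_graph = [n for n in nodes if n.name in comm_names]
    compute_graph = [n for n in nodes if n.name not in comm_names]
    return comm_graph, compute_graph, boundary

nodes = [
    Node("w1_shard", "placeholder"),
    Node("w2_shard", "placeholder"),
    Node("x", "placeholder"),
    Node("w1", "all_gather", ("w1_shard",)),
    Node("h", "matmul", ("x", "w1")),
    Node("w2", "all_gather", ("w2_shard",)),
    Node("y", "matmul", ("h", "w2")),
]

comm_graph, compute_graph, boundary = partition_fsdp(nodes)
```

Here `comm_graph` holds the sharded params and their all-gathers, `compute_graph` holds the remaining ops, and `boundary` names the unsharded params that connect the two.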