-
Notifications
You must be signed in to change notification settings - Fork 5.7k
[AutoParallel] support auto dp comm #72540
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
Codecov ReportAttention: Patch coverage is
❌ Your patch status has failed because the patch coverage (79.45%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## develop #72540 +/- ##
==========================================
Coverage ? 79.45%
==========================================
Files ? 3
Lines ? 73
Branches ? 0
==========================================
Hits ? 58
Misses ? 15
Partials ? 0 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
PR Category
Auto Parallel
PR Types
New features
Description
当前在 wanx 这类视频生成模型上可能会出现不同卡上数据的 shape 不一致的情况,这种不均匀分布的情况目前自动并行场景下不支持,如果触发了 dp 维度的通信程序就会 hang 住,为了让自动并行能够适配 wanx 模型的 sharding 等训练策略,我们期望不引入 batch 维度的通信
想要 dp 维度不通信,那么我们在 forward + backward 这个阶段肯定 dp 维的切分状态都是 Replicate, 而且需要是 “Fake Replicate” (不同卡之间的数据并不一样) 这样 Fake Replicate 的数据可以通过下面的方法构造:
如何从非法切分状态转合法切分状态?
sharding stage1/stage2 sharding 相关的通信主要就是对 grad 的 reduce_scatter, 在 opt 之前 grad 首先会是 partial 状态,然后 partial=>shard 会触发 reduce_scatter
sharding stage3 的时候额外还多了 param 的通信,但是 param 的通信和其他的通信是相对独立的主要逻辑是在计算之前通过 all-gather 恢复成 Replicate 状态,在计算完成后再通过 slice 回到 shard 的状态
param 的通信我们只需要让 param 的 dp 维标记为 shard,借助切分推导就可以实现自动 all-gather
对于 grad 正常切分状态下,我们会用 Replicate 状态的 param 和 shard 状态的计算结果计算得到 partial 状态的 grad,然后对 grad 做 reduce_scatter
对于 grad 非正常切分状态,我们会用 Replicate 状态的 param 和 Fake Replicate 状态的计算结果计算得到 Fake Replicate 状态的 grad
这时候如果我们把 grad reshard 成 partial,replicate 到 partial 并不会引入通信,然后我们就会发现 grad 变成了一个合法的 partial 状态,后续就可以继续正常的做 reduce_scatter 转化为 shard 状态后输入到优化器 op 里面
具体实现
添加一个 FLAG,在打开 FLAG 的时候对 param 将 grad 从 reshard(非法) 到 partial(合法) 再从 partial 到 shard 即可兼容 sharding stage1/stage2/stage3
具体实现涉及两个部分,一部分是修改 shard_dataloader 的逻辑,一部分是修改 shard_optimizer 的逻辑
1、shard_dataloader
shard_dataloader 本来的逻辑是将输入的数据根据传入的 mesh 转化为 shard 状态,开启 flag 后需要修改这部分的逻辑为,本地 slice 后通过 dtensor_from local 接口转化为 Replicate 状态
2、shard_optimizer
在 _ShardOptimizer 的 apply_optimizer 中将 Replicate 状态的 grad 转化为 partial 状态
为了方便用户使用,添加 dist.enable_auto_dp() 的 api (内部就是设置了一个 flag),让普通用户只需要使用这个 api 就可以实现最简单的数据并行,不需要手动调 shard_dataloader
Pcard-76459