[AutoParallel] support auto dp comm #72540

Open · wants to merge 12 commits into develop

Conversation

@AndSonder (Contributor) commented Apr 28, 2025

PR Category

Auto Parallel

PR Types

New features

Description

Currently, on video-generation models such as wanx, the data on different cards may have inconsistent shapes. Such uneven distribution is not yet supported in the auto-parallel scenario: if any communication along the dp dimension is triggered, the program hangs. To let auto parallel support training strategies such as sharding for the wanx model, we want to avoid introducing any communication along the batch dimension.

If the dp dimension must not communicate, then throughout forward + backward the dp-dim placement has to stay Replicate, and it has to be a "Fake Replicate" (the data actually differs across cards). Such Fake Replicate data can be constructed as follows:

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])

# Each rank builds a different local tensor ...
if dist.get_rank() == 0:
    a = paddle.ones([8])
else:
    a = paddle.zeros([8])

# ... but marks it as Replicate: a "Fake Replicate" tensor.
a = dist.auto_parallel.api.dtensor_from_local(a, mesh, [dist.Replicate()])

How do we convert this illegal sharding state into a legal one?

For sharding stage1/stage2, the sharding-related communication is mainly the reduce_scatter on grads: before the optimizer step the grad is first in the partial state, and the partial => shard transition triggers the reduce_scatter.

Sharding stage3 additionally has param communication, but it is largely independent of the rest: the main logic is to restore the param to the Replicate state via all-gather before compute, and to slice it back to the shard state after compute.

For the param communication, we only need to mark the param's dp dim as shard; the sharding propagation then inserts the all-gather automatically (a minimal sketch follows).
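A minimal sketch of this marking, using the public dist.shard_tensor API (illustrative only, not the exact code in this PR):

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])

# Marking the param's dp dim as Shard(0): the sharding propagation rules
# then insert the all-gather back to Replicate before the compute op.
param = paddle.randn([16, 16])
param = dist.shard_tensor(param, mesh, [dist.Shard(0)])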

For grads under the normal sharding state, a Replicate param and a sharded activation produce a grad in the partial state, which is then reduce_scattered.

For grads under the abnormal sharding state, a Replicate param and a Fake Replicate activation produce a Fake Replicate grad.

At this point, if we reshard the grad to partial, the replicate-to-partial transition introduces no communication, and the grad becomes a legal partial tensor. From there the normal reduce_scatter converts it to the shard state and it can be fed into the optimizer op (see the sketch below).
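A minimal sketch of this trick, using the public dist.reshard API (names are illustrative; this is not the exact code in the PR):

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])

# A "Fake Replicate" grad: marked Replicate, but values differ per rank.
local_grad = paddle.full([8], float(dist.get_rank()))
grad = dist.auto_parallel.api.dtensor_from_local(local_grad, mesh, [dist.Replicate()])

# Replicate -> Partial only re-labels the placement and introduces no communication.
grad = dist.reshard(grad, mesh, [dist.Partial()])

# Partial -> Shard(0) triggers the usual reduce_scatter; the grad is now in a
# legal sharded state and can be fed into the optimizer op.
grad = dist.reshard(grad, mesh, [dist.Shard(0)])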

Implementation

Add a FLAG. When the FLAG is enabled, reshard each param's grad from the illegal (Fake Replicate) state to the legal partial state, and then from partial to shard. This is enough to be compatible with sharding stage1/stage2/stage3.

The implementation touches two parts: the shard_dataloader logic and the shard_optimizer logic.

1. shard_dataloader
Originally, shard_dataloader converts the input data to the shard state according to the given mesh. With the flag enabled, this logic is changed to slice the data locally and convert it to the Replicate state via the dtensor_from_local API (a sketch follows).
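A sketch of the modified behaviour, assuming a 2-card dp mesh; to_fake_replicate is a hypothetical helper name, not the actual function inside shard_dataloader:

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])

def to_fake_replicate(global_batch):
    # Keep only this rank's local slice along the batch dim ...
    dp_size = len(mesh.process_ids)
    local_batch = paddle.split(global_batch, dp_size, axis=0)[dist.get_rank()]
    # ... and mark it Replicate, so no dp communication is ever inferred.
    return dist.auto_parallel.api.dtensor_from_local(
        local_batch, mesh, [dist.Replicate()]
    )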

2. shard_optimizer
In _ShardOptimizer's apply_optimizer, convert the grads from the Replicate state to the partial state.

To make this easy to use, we add a dist.enable_auto_dp() API (internally it just sets the flag), so that ordinary users only need this API to get the simplest data parallelism, without manually calling shard_dataloader (a usage sketch follows).
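A usage sketch, assuming enable_auto_dp() is the only distributed call an ordinary user needs for plain data parallelism; MyModel and train_loader are hypothetical placeholders:

import paddle
import paddle.distributed as dist

# Set the flag; data slicing and grad handling are taken care of internally,
# so no manual shard_dataloader call is needed.
dist.enable_auto_dp()

model = MyModel()                       # hypothetical user model
opt = paddle.optimizer.AdamW(parameters=model.parameters())

for inputs, labels in train_loader:     # hypothetical plain DataLoader
    loss = paddle.nn.functional.mse_loss(model(inputs), labels)
    loss.backward()
    opt.step()
    opt.clear_grad()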

Pcard-76459

paddle-bot bot commented Apr 28, 2025

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI result first. See the Paddle CI Manual for details.

@codecov-commenter

Codecov Report

Attention: Patch coverage is 79.45205% with 15 lines in your changes missing coverage. Please review.

Please upload report for BASE (develop@7a7f5cb). Learn more about missing BASE report.

Files with missing lines                         | Patch % | Lines
python/paddle/distributed/auto_parallel/api.py   | 77.94%  | 15 Missing ⚠️

❌ Your patch status has failed because the patch coverage (79.45%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop   #72540   +/-   ##
==========================================
  Coverage           ?   79.45%           
==========================================
  Files              ?        3           
  Lines              ?       73           
  Branches           ?        0           
==========================================
  Hits               ?       58           
  Misses             ?       15           
  Partials           ?        0           

@AndSonder changed the title from "[AutoParallel] support manual dp comm" to "[AutoParallel] support auto dp comm" on May 7, 2025