-
Notifications
You must be signed in to change notification settings - Fork 826
新增local_map API的中文文档 #7245
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
新增local_map API的中文文档 #7245
Changes from 13 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
07939de
【Hackathon 7th No.19】对应开发的API中文文档
zty-king e231444
【Hackathon 7th No.19】对应开发的API中文文档
zty-king ace6dcf
【Hackathon 7th No.19】对应开发的API中文文档
zty-king 2a91133
【Hackathon 7th No.19】对应开发的API中文文档
zty-king 53e3952
Merge branch 'develop' of https://github.com/PaddlePaddle/docs into my
zty-king ce022b8
【Hackathon 7th No.19】对应开发的API中文文档
zty-king 8e4ef95
Merge branch 'develop' of https://github.com/PaddlePaddle/docs into my
zty-king fbbb4f7
提交local_map的中文文档
zty-king 5079591
提交local_map的中文文档
zty-king 57cc80f
提交local_map的中文文档
zty-king ffcff3d
提交local_map的中文文档
zty-king e4aecb3
修改文档格式
zty-king 8b138fb
修改文档格式
zty-king 1c5ab83
修改文档格式
zty-king 2f20eed
修改文档格式
zty-king f3811b7
修改文档格式
zty-king cfc270a
修改格式
zty-king 2351345
添加总览
zty-king 1b9d318
修改文档格式
zty-king dd77184
修改文档格式
zty-king File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
.. _cn_api_paddle_distributed_local_map: | ||
|
||
local_map | ||
------------------------------- | ||
|
||
.. py:function:: paddle.distributed.local_map(func, out_placements, in_placements=None, process_mesh=None, reshard_inputs=False) | ||
|
||
local_map 是一个函数装饰器,用于在分布式训练中实现局部计算操作。它允许用户将分布式张量传递给为普通张量编写的函数,通过自动处理张量转换,使得用户可以像编写单卡代码一样实现这些局部操作。 | ||
|
||
|
||
参数 | ||
::::::::: | ||
|
||
- **func** (Callable) - 要应用于分布式张量本地分片的函数 | ||
- **out_placements** (list[list[dist.Placement]]) - 指定输出张量的分布策略。每个元素是一个 Placement 列表,描述对应输出张量的分布方式。对于不具有分布式属性的输出应设为 None | ||
- **in_placements** (list[list[dist.Placement]] | None) - 指定输入张量的要求分布。每个元素是一个 Placement 列表,描述对应输入张量的分布要求.对于不具有分布式属性的输入应设为 None,默认为 None | ||
- **process_mesh** (Optional[ProcessMesh]) - 计算设备网格。如未指定则从输入张量推断 | ||
- **reshard_inputs** (bool) - 当输入张量分布不符合要求时是否自动重分布。默认 False | ||
|
||
代码示例 | ||
::::::::: | ||
|
||
.. code-block:: python | ||
|
||
import paddle | ||
import paddle.distributed as dist | ||
from paddle import Tensor | ||
from paddle.distributed import ProcessMesh | ||
|
||
def custom_function(x): | ||
mask = paddle.zeros_like(x) | ||
if dist.get_rank() == 0: | ||
mask[1:3] = 1 | ||
else: | ||
mask[4:7] = 1 | ||
x = x * mask | ||
mask_sum = paddle.sum(x) | ||
mask_sum = mask_sum / mask.sum() | ||
return mask_sum | ||
|
||
dist.init_parallel_env() | ||
mesh = ProcessMesh([0, 1], dim_names=["x"]) | ||
|
||
local_input = paddle.arange(0, 10, dtype='float32') | ||
local_input = local_input + dist.get_rank() | ||
|
||
input_dist = dist.auto_parallel.api.dtensor_from_local( | ||
local_input, | ||
mesh, | ||
[dist.Shard(0)] | ||
) | ||
|
||
# 使用 local_map 包装函数 | ||
wrapped_func = dist.local_map( | ||
custom_function, | ||
out_placements=[[dist.Partial(dist.ReduceType.kRedSum)]], | ||
in_placements=[[dist.Shard(0)]], | ||
process_mesh=mesh | ||
) | ||
|
||
# 应用函数到分布式张量 | ||
output_dist = wrapped_func(input_dist) | ||
|
||
# 收集并打印结果 | ||
local_value = output_dist._local_value() | ||
gathered_values: list[Tensor] = [] | ||
dist.all_gather(gathered_values, local_value) | ||
print(f"[Rank 0] local_value={gathered_values[0].item()}") | ||
# [Rank 0] local_value=1.5 | ||
print(f"[Rank 1] local_value={gathered_values[1].item()}") | ||
# [Rank 1] local_value=6.0 | ||
print(f"global_value (distributed)={output_dist.item()}") | ||
# global_value (distributed)=7.5 | ||
|
||
|
||
方法 | ||
::::::::: | ||
|
||
wrapped_func() | ||
''''''''' | ||
|
||
包装后返回的函数。该函数会: | ||
|
||
1. 将输入的分布式张量转换为本地张量 | ||
2. 在本地执行函数计算 | ||
3. 将计算结果按照指定的分布策略转换回分布式张量 | ||
|
||
**参数** | ||
|
||
- **args** (Any) - 位置参数,通常包含分布式张量 | ||
- **kwargs** (Any) - 额外的关键字参数 | ||
|
||
**返回** | ||
|
||
按照 out_placements 指定的分布策略转换后的分布式张量 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
|
||
**使用场景** | ||
|
||
local_map 可以用于但不限于以下场景: | ||
|
||
1. 带 mask 的 loss 计算:需要在每张卡上独立计算 masked token 的 loss | ||
2. MoE (混合专家模型)相关计算: | ||
- aux_loss 计算:基于每张卡上专家分配到的局部 token 数进行计算 | ||
- z_loss 计算:对每张卡上的 logits 独立计算 z_loss | ||
- 张量 reshape 操作:在局部维度上进行 shape 变换 | ||
3. 需要对分布式张量应用普通张量函数的场景 | ||
4. 需要混合处理分布式张量和普通张量的场景 | ||
|
||
**注意事项** | ||
|
||
1. 输出必须指定正确的分布策略以确保结果正确性 | ||
2. 在函数中可以像单卡编程一样使用常规的 tensor 操作 | ||
3. 计算结果会自动根据分布策略进行聚合,无需手动添加通信操作 | ||
4. 当指定 in_placements 时,输入张量的分布必须匹配要求,除非启用 reshard_inputs | ||
5. 所有分布式张量必须在同一个 process_mesh 上 |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
中文示例代码统一用
COPY-FROM: paddle.xxx
的形式,xxx 为 API 的调用路径,目的为和英文代码保持一致There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done