
Add Chinese documentation for the local_map API #7245

Merged
Merged 20 commits on May 12, 2025 · Changes from 13 commits
115 changes: 115 additions & 0 deletions docs/api/paddle/distributed/local_map_cn.rst
@@ -0,0 +1,115 @@
.. _cn_api_paddle_distributed_local_map:

local_map
-------------------------------

.. py:function:: paddle.distributed.local_map(func, out_placements, in_placements=None, process_mesh=None, reshard_inputs=False)

local_map is a function decorator used to perform local computations in distributed training. It allows users to pass distributed tensors into functions written for ordinary tensors: by handling the tensor conversions automatically, these local operations can be written just like single-card code.


Parameters
:::::::::

- **func** (Callable) - The function to be applied to the local shards of the distributed tensors.
- **out_placements** (list[list[dist.Placement]]) - The distribution strategy of the output tensors. Each element is a list of Placement objects describing how the corresponding output tensor is distributed. Set it to None for outputs that carry no distributed attributes.
- **in_placements** (list[list[dist.Placement]]|None, optional) - The required distribution of the input tensors. Each element is a list of Placement objects describing the distribution required for the corresponding input tensor. Set it to None for inputs that carry no distributed attributes. Default: None.
- **process_mesh** (ProcessMesh|None, optional) - The process mesh on which the computation runs. If not specified, it is inferred from the input tensors.
- **reshard_inputs** (bool, optional) - Whether to automatically reshard input tensors whose distribution does not match the requirement. Default: False.

Code Example
:::::::::

.. code-block:: python
Collaborator
Chinese example code should uniformly use the form COPY-FROM: paddle.xxx, where xxx is the API's call path, so that it stays consistent with the English documentation.

Contributor Author
Done

import paddle
import paddle.distributed as dist
from paddle import Tensor
from paddle.distributed import ProcessMesh

def custom_function(x):
    mask = paddle.zeros_like(x)
    if dist.get_rank() == 0:
        mask[1:3] = 1
    else:
        mask[4:7] = 1
    x = x * mask
    mask_sum = paddle.sum(x)
    mask_sum = mask_sum / mask.sum()
    return mask_sum

dist.init_parallel_env()
mesh = ProcessMesh([0, 1], dim_names=["x"])

local_input = paddle.arange(0, 10, dtype='float32')
local_input = local_input + dist.get_rank()

input_dist = dist.auto_parallel.api.dtensor_from_local(
    local_input,
    mesh,
    [dist.Shard(0)]
)

# Wrap the function with local_map
wrapped_func = dist.local_map(
    custom_function,
    out_placements=[[dist.Partial(dist.ReduceType.kRedSum)]],
    in_placements=[[dist.Shard(0)]],
    process_mesh=mesh
)

# Apply the wrapped function to the distributed tensor
output_dist = wrapped_func(input_dist)

# Gather and print the results
local_value = output_dist._local_value()
gathered_values: list[Tensor] = []
dist.all_gather(gathered_values, local_value)
print(f"[Rank 0] local_value={gathered_values[0].item()}")
# [Rank 0] local_value=1.5
print(f"[Rank 1] local_value={gathered_values[1].item()}")
# [Rank 1] local_value=6.0
print(f"global_value (distributed)={output_dist.item()}")
# global_value (distributed)=7.5


Methods
:::::::::

wrapped_func()
'''''''''

The wrapped function that is returned. It will:

1. Convert the input distributed tensors into local tensors
2. Run the function on the local tensors
3. Convert the results back into distributed tensors according to the specified distribution strategy

**Parameters**

- **args** (Any) - Positional arguments, usually containing distributed tensors
- **kwargs** (Any) - Additional keyword arguments

**Returns**

Distributed tensors converted according to the distribution strategy specified by out_placements (see the conceptual sketch below).
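As a rough illustration of the three steps above, here is a minimal conceptual sketch of such a wrapper, not Paddle's actual implementation, handling only a single tensor input and a single tensor output and reusing the helpers that appear in the example above (`_local_value`, `dtensor_from_local`):

.. code-block:: python

import paddle.distributed as dist

def make_wrapped_func(func, out_placements, process_mesh):
    # Conceptual stand-in for what dist.local_map returns (simplified sketch).
    def wrapped_func(dist_input):
        # 1. Convert the input distributed tensor to its local shard
        local_input = dist_input._local_value()
        # 2. Run the user function on the local shard, as in single-card code
        local_output = func(local_input)
        # 3. Convert the result back to a distributed tensor with the
        #    requested placements
        return dist.auto_parallel.api.dtensor_from_local(
            local_output, process_mesh, out_placements[0]
        )
    return wrapped_func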
Contributor
local_map is a function API and does not need a Methods section; only class APIs need to describe their member methods.

Contributor Author
Done

**Usage Scenarios**

local_map can be used in, but is not limited to, the following scenarios:

1. Masked loss computation: the loss over masked tokens needs to be computed independently on each card
2. MoE (Mixture of Experts) related computations (see the sketch after this list):
   - aux_loss computation: computed from the number of local tokens assigned to each expert on each card
   - z_loss computation: z_loss computed independently on the logits of each card
   - Tensor reshape operations: shape transformations on the local dimensions
3. Applying ordinary tensor functions to distributed tensors
4. Processing a mix of distributed tensors and ordinary tensors
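A minimal sketch for scenario 2, assuming hypothetical local shards `gate_probs` (per-token routing probabilities) and `dispatch_mask` (per-token expert assignment) sharded along the token dimension; this is not Paddle's MoE implementation, just the same wrapping pattern as `custom_function` above:

.. code-block:: python

import paddle
import paddle.distributed as dist

def aux_loss_fn(gate_probs, dispatch_mask):
    # Both arguments are the local shards on this card.
    # Fraction of local tokens dispatched to each expert.
    tokens_per_expert = dispatch_mask.astype('float32').mean(axis=0)
    # Mean gate probability per expert over the local tokens.
    prob_per_expert = gate_probs.mean(axis=0)
    # Load-balancing auxiliary loss computed from local statistics only.
    return paddle.sum(tokens_per_expert * prob_per_expert)

aux_loss_wrapped = dist.local_map(
    aux_loss_fn,
    out_placements=[[dist.Partial(dist.ReduceType.kRedSum)]],
    in_placements=[[dist.Shard(0)], [dist.Shard(0)]],
)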

**Notes**

1. The correct distribution strategy must be specified for the outputs to guarantee correct results
2. Inside the function, regular tensor operations can be used just as in single-card programming
3. The results are aggregated automatically according to the distribution strategy; no manual communication operations are needed
4. When in_placements is specified, the distribution of the input tensors must match the requirement unless reshard_inputs is enabled (see the sketch after this list)
5. All distributed tensors must be on the same process_mesh
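A short sketch for note 4, reusing `custom_function` and `mesh` from the code example above: if an input arrives with a Replicate placement instead of the required Shard(0), setting `reshard_inputs=True` lets the wrapped function redistribute it first, as described in the parameter list (a hedged illustration, not an exhaustive description of the behavior).

.. code-block:: python

import paddle
import paddle.distributed as dist

# The same tensor on every card, marked as replicated on the mesh.
replicated_input = dist.auto_parallel.api.dtensor_from_local(
    paddle.arange(0, 10, dtype='float32'), mesh, [dist.Replicate()]
)

wrapped_with_reshard = dist.local_map(
    custom_function,
    out_placements=[[dist.Partial(dist.ReduceType.kRedSum)]],
    in_placements=[[dist.Shard(0)]],   # required input distribution
    process_mesh=mesh,
    reshard_inputs=True,               # reshard Replicate -> Shard(0) automatically
)
output_dist = wrapped_with_reshard(replicated_input)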