# Support of Torch Distributed API in PyTorch/XLA
Before the 2.5 release, PyTorch/XLA only supported collective ops through our custom `torch_xla.core.xla_model.*` APIs. In the 2.5 release, we adopt `torch.distributed.*` in PyTorch/XLA for both the dynamo and non-dynamo cases.

## Collective ops lowering
### Collective ops lowering stack
After introducing the [traceable collective communication APIs](https://github.com/pytorch/pytorch/issues/93173), dynamo can support collective ops by reimplementing their lowerings in PyTorch/XLA, since the ops are traceable through `torch.ops._c10d_functional` calls. The following figure shows how a collective op, `all_reduce` in this case, is lowered between torch and torch_xla:

<img src="_static/img/dist_op_stack.png" alt="Collective ops lowering stack" width="500" height="400">

_<span style="text-decoration:underline;">Figure 1. Collective ops lowering stack</span>_
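
For intuition, here is a minimal sketch (added for illustration, not part of the original commit) of the functional, traceable form of an all-reduce; `funcol.all_reduce` dispatches to `torch.ops._c10d_functional.all_reduce` under the hood, which is the call PyTorch/XLA lowers as shown in Figure 1. It assumes a process group has already been initialized:

```Python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

def traceable_all_reduce(t: torch.Tensor) -> torch.Tensor:
  # Unlike dist.all_reduce, the functional collective returns a new tensor
  # instead of mutating its input, which is what makes the op traceable.
  return funcol.all_reduce(t, "sum", dist.group.WORLD)
```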
### Non-dynamo case
Collective ops are lowered by registering `ProcessGroupXla`, which is derived from the PyTorch `ProcessGroup`:

```Python
# torch_xla/distributed/xla_backend.py
def _create_xla_process_group(prefix_store, rank, size, timeout):
  assert not xr.is_spmd(
  ), "XLA backend is not supported with SPMD. Please use a CPU process group instead."
  return ProcessGroupXla(prefix_store, rank, size, timeout)


def _register_xla_backend():
  dist.Backend.register_backend('xla', _create_xla_process_group, devices='xla')


class ProcessGroupXla(ProcessGroup):
  ...

  def allreduce(self, tensors, all_reduce_options):
    ...

  def allgather(self, output_tensors_list, input_tensors, opts=None):
    ...
```

The corresponding xla dist backend is initialized when we enter the multiprocess function call:

```Python
def _mp_fn(rank):
  dist.init_process_group("xla", init_method='xla://')
```

With `dist.init_process_group`, collective ops will be dispatched based on the process group instance:

```Python
# E.g., pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
@_exception_logger
def all_gather(tensor_list, tensor, group=None, async_op=False):
    ...
    group = group or _get_default_group()
    work = group.allgather([tensor_list], [tensor])  # uses ProcessGroupXla.allgather instead
```
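
As a quick sanity check (a sketch added here for illustration, not part of the original example), once the xla backend has been initialized inside `_mp_fn`, the default process group reports `xla` as its backend, so the regular `torch.distributed` calls above route to `ProcessGroupXla`:

```Python
import torch.distributed as dist

def _mp_fn(rank):
  dist.init_process_group("xla", init_method='xla://')
  # The default group is now backed by ProcessGroupXla, so collectives such as
  # dist.all_gather or dist.all_reduce dispatch to the methods shown above.
  assert dist.get_backend() == 'xla'
```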
### Dynamo case
For the dynamo case, certain collective ops are remapped to the new functions in [pytorch/torch/distributed/_functional_collectives.py](https://github.com/pytorch/pytorch/blob/v2.5.0-rc10/torch/distributed/_functional_collectives.py#L1129-L1150). For example, `all_reduce()` is mapped to `all_reduce_inplace()`, which eventually calls `torch.ops._c10d_functional.all_reduce()`. Once we reach `_c10d_functional`, we can rewrite the op through the PyTorch/XLA lowering:

```C++
at::Tensor all_reduce(const at::Tensor& self, std::string reduceOp,
                      std::string /*group_name*/) {...}

TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
  m.impl("all_reduce", all_reduce);
}
```
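
For end-to-end usage, here is a sketch modeled loosely on the dynamo tests in [test_collective_ops_tpu.py](https://github.com/pytorch/xla/blob/v2.5.0-rc10/test/pjrt/test_collective_ops_tpu.py) (treat the details as illustrative rather than canonical): compiling a function that calls `dist.all_reduce` with the `openxla` backend lets dynamo trace the op through `_c10d_functional` and pick up the lowering above.

```Python
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

def _mp_fn(rank):
  dist.init_process_group("xla", init_method='xla://')

  def all_reduce_fn(inp):
    dist.all_reduce(inp, dist.ReduceOp.SUM)
    return inp

  # Compiling with the openxla backend traces dist.all_reduce through
  # torch.ops._c10d_functional.all_reduce, which hits the XLA lowering above.
  compiled = torch.compile(all_reduce_fn, backend='openxla')
  t = torch.tensor([xr.global_ordinal()], dtype=torch.float32, device=xm.xla_device())
  reduced = compiled(t)  # holds the sum of all ranks' ordinals on every process

if __name__ == '__main__':
  torch_xla.launch(_mp_fn)
```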
## API description
For release 2.5, we now support four collective operations for both the dynamo and non-dynamo cases. Our goal is to align the distributed operation (dist op) APIs with PyTorch's upstream implementation. Note that distributed collective ops do not work with GSPMD, where collective ops are automatically injected at the XLA compiler level. While the distributed function signatures remain consistent, certain input restrictions still apply. For instance, specifying multiple groups for distributed collective operations is not yet supported. For usage examples, refer to [test_collective_ops_tpu.py](https://github.com/pytorch/xla/blob/v2.5.0-rc10/test/pjrt/test_collective_ops_tpu.py), which demonstrates the use of collective ops in both dynamo and non-dynamo scenarios.

To use the distributed ops, we first need to call `dist.init_process_group` in the multiprocess function:

```Python
import torch.distributed as dist
import torch_xla

def _mp_fn(rank):
  dist.init_process_group("xla", init_method='xla://')
  ...

if __name__ == '__main__':
  torch_xla.launch(_mp_fn)
```
Below are the details for collective operation functions:

```Python
dist.all_reduce(input: torch.Tensor, op: dist.ReduceOp = ReduceOp.SUM)
```
`all_reduce` performs an in-place reduction on the `input` tensor by aggregating data from all nodes.
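
A minimal usage sketch (illustrative, not from the original doc), assumed to run inside the `_mp_fn` shown above after the xla process group is initialized:

```Python
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Inside _mp_fn, after dist.init_process_group("xla", init_method='xla://'):
t = torch.tensor([xr.global_ordinal()], dtype=torch.float32, device=xm.xla_device())
dist.all_reduce(t, op=dist.ReduceOp.SUM)  # t now holds the sum of every rank's ordinal
```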

```Python
dist.all_gather_into_tensor(output, input)
```
`all_gather_into_tensor` gathers the input tensor from all nodes and updates the `output` tensor in-place. It also returns an alias of the output.
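
A usage sketch under the same assumptions as above (shapes and values are illustrative):

```Python
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Inside _mp_fn, after dist.init_process_group("xla", init_method='xla://'):
world_size = xr.world_size()
inp = torch.tensor([float(xr.global_ordinal())], device=xm.xla_device())
out = torch.zeros(world_size, device=xm.xla_device())
dist.all_gather_into_tensor(out, inp)  # out[i] == i on every process
```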

```Python
dist.reduce_scatter_tensor(output, input, op: dist.ReduceOp = ReduceOp.SUM)
```
`reduce_scatter_tensor` reduces the input tensor across all nodes and distributes the result to the `output` tensor in-place. It returns an alias of the output.
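
A usage sketch under the same assumptions:

```Python
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Inside _mp_fn, after dist.init_process_group("xla", init_method='xla://'):
world_size = xr.world_size()
inp = torch.ones(world_size * 2, device=xm.xla_device())
out = torch.zeros(2, device=xm.xla_device())
dist.reduce_scatter_tensor(out, inp, op=dist.ReduceOp.SUM)  # every element of out equals world_size
```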

```Python
dist.all_to_all_single(output, input, output_split_sizes=None, input_split_sizes=None)
```
`all_to_all_single` performs an all-to-all communication, updating the `output` tensor in-place and returning its alias.

Note: Although `output_split_sizes` and `input_split_sizes` are accepted as arguments, they must be either None or set to all 1s. This limitation reflects a compromise between maintaining PyTorch’s API signature and the constraints of the XLA AllToAll operation.
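
A usage sketch under the same assumptions, keeping the split sizes at their default of `None` as required by the note above:

```Python
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

# Inside _mp_fn, after dist.init_process_group("xla", init_method='xla://'):
world_size = xr.world_size()
rank = xr.global_ordinal()
inp = torch.arange(world_size, dtype=torch.float32, device=xm.xla_device()) + rank * world_size
out = torch.zeros(world_size, dtype=torch.float32, device=xm.xla_device())
dist.all_to_all_single(out, inp)  # out[j] == j * world_size + rank on each process
```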
