
R2.5 #8285

Closed

wants to merge 18 commits

Changes from all commits (18 commits)
87304aa
Add torch_pin to 2.5 release branch
ManfeiBai Sep 9, 2024
ba13b4e
Follow 2 releases grace period for API deprecation (#8003)
zpcore Sep 12, 2024
3c7daa2
Set xla_tpu_enable_flash_attention=false to enable libtpu pin update …
bhavya01 Sep 13, 2024
1717e25
Use Tuple instead of tuple in scan.py (#8067)
tengyifei Sep 25, 2024
8f43eef
cherry pick collective op all_to_all_single (#8077)
zpcore Sep 27, 2024
5cc01cf
Force github workflow catch PyTorch commit from release/2.5 branch fo…
ManfeiBai Sep 27, 2024
3bb8736
[cherry-pick] Update Openxla-pin to Sep13, libtpu-pin to Sep13, jax t…
ManfeiBai Sep 27, 2024
31a0a92
[Cherry-pick] Enable cross entropy loss for xla autocast with FP32 pr…
ManfeiBai Oct 2, 2024
c074e31
change r2.5 to use jax stable version 0.4.33 (#8062)
ManfeiBai Oct 3, 2024
8ff43a9
Allow MpDeviceLoader to shard dictionaries of tensor for 2.5 release …
bhavya01 Oct 4, 2024
ee0f159
Update cuda_deps.yaml for r2.5 (#8219)
ManfeiBai Oct 4, 2024
fbede65
Part 1: Introduce multi-node SPMD support for Neuron (#8204) (#8224)
jeffhataws Oct 7, 2024
f1c8626
make 2.5 whl pypi compatible (#8268)
lsy323 Oct 17, 2024
201d2c4
[cherrypick] add dist op doc (#8273) (#8274)
zpcore Oct 17, 2024
4df24fd
Update collective op doc based on feedback (#8277)
zpcore Oct 18, 2024
01c23a4
Add torch_pin to 2.5 release branch (#8279)
JackCaoG Oct 18, 2024
2992ae3
Add torch_pin to 2.5 release branch (#8280)
JackCaoG Oct 18, 2024
396608c
[cherry-pick] Update README.md (#8281) (#8283)
ManfeiBai Oct 18, 2024
2 changes: 1 addition & 1 deletion .github/workflows/build_and_test.yml
@@ -28,7 +28,7 @@ jobs:
      - id: commit
        name: Get latest torch commit
        run: |
-          echo "torch_commit=$(git ls-remote https://github.com/pytorch/pytorch.git HEAD | awk '{print $1}')" >> "$GITHUB_OUTPUT"
+          echo "torch_commit=$(git ls-remote https://github.com/pytorch/pytorch.git refs/heads/release/2.5 | awk '{print $1}')" >> "$GITHUB_OUTPUT"

  build-torch-xla:
    name: "Build PyTorch/XLA"
1 change: 1 addition & 0 deletions .torch_pin
@@ -0,0 +1 @@
release/2.5
34 changes: 23 additions & 11 deletions README.md
@@ -26,7 +26,7 @@ started:
To install PyTorch/XLA stable build in a new TPU VM:

```
-pip install torch~=2.4.0 torch_xla[tpu]~=2.4.0 -f https://storage.googleapis.com/libtpu-releases/index.html
+pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html
```

To install PyTorch/XLA nightly build in a new TPU VM:
@@ -41,7 +41,7 @@ pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-release
PyTorch/XLA now provides GPU support through a plugin package similar to `libtpu`:

```
-pip install torch~=2.4.0 torch_xla~=2.4.0 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla_cuda_plugin-2.4.0-py3-none-any.whl
+pip install torch~=2.5.0 torch_xla~=2.5.0 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla_cuda_plugin-2.5.0-py3-none-any.whl
```

## Getting Started
@@ -154,15 +154,14 @@ GPU and nightly builds are available in our public GCS bucket.

| Version | Cloud GPU VM Wheels |
| --- | ----------- |
-| 2.4 (CUDA 12.1 + Python 3.9) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp39-cp39-linux_x86_64.whl` |
-| 2.4 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp310-cp310-linux_x86_64.whl` |
-| 2.4 (CUDA 12.1 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp311-cp311-linux_x86_64.whl` |
-| 2.4 (CUDA 12.4 + Python 3.9) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.4.0-cp39-cp39-linux_x86_64.whl` |
-| 2.4 (CUDA 12.4 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.4.0-cp310-cp310-linux_x86_64.whl` |
-| 2.4 (CUDA 12.4 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.4.0-cp311-cp311-linux_x86_64.whl` |
-| nightly (Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev-cp38-cp38-linux_x86_64.whl` |
-| nightly (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev-cp310-cp310-linux_x86_64.whl` |
-| nightly (CUDA 12.1 + Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0.dev-cp38-cp38-linux_x86_64.whl` |
+| 2.5 (CUDA 12.1 + Python 3.9) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl` |
+| 2.5 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
+| 2.5 (CUDA 12.1 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl` |
+| 2.5 (CUDA 12.4 + Python 3.9) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl` |
+| 2.5 (CUDA 12.4 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
+| 2.5 (CUDA 12.4 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl` |
+| nightly (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl` |
+| nightly (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl` |

<details>

@@ -194,6 +193,7 @@ The torch wheel version `2.5.0.dev20240820+cpu` can be found at https://download

| Version | Cloud TPU VMs Wheel |
|---------|-------------------|
| 2.4 (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.3 (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.2 (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.1 (XRT + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl` |
@@ -203,6 +203,15 @@ The torch wheel version `2.5.0.dev20240820+cpu` can be found at https://download

| Version | GPU Wheel |
| --- | ----------- |
| 2.5 (CUDA 12.1 + Python 3.9) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl` |
| 2.5 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.5 (CUDA 12.1 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl` |
| 2.5 (CUDA 12.4 + Python 3.9) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl` |
| 2.5 (CUDA 12.4 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.5 (CUDA 12.4 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl` |
| 2.4 (CUDA 12.1 + Python 3.9) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp39-cp39-manylinux_2_28_x86_64.whl` |
| 2.4 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.4 (CUDA 12.1 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp311-cp311-manylinux_2_28_x86_64.whl` |
| 2.3 (CUDA 12.1 + Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp38-cp38-manylinux_2_28_x86_64.whl` |
| 2.3 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.3 (CUDA 12.1 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp311-cp311-manylinux_2_28_x86_64.whl` |
@@ -217,6 +226,7 @@ The torch wheel version `2.5.0.dev20240820+cpu` can be found at https://download

| Version | Cloud TPU VMs Docker |
| --- | ----------- |
| 2.5 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_tpuvm` |
| 2.4 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm` |
| 2.3 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_tpuvm` |
| 2.2 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_tpuvm` |
@@ -233,13 +243,15 @@ docker run --privileged --net host --shm-size=16G -it us-central1-docker.pkg.dev

| Version | GPU CUDA 12.4 Docker |
| --- | ----------- |
| 2.5 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.4` |
| 2.4 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.4` |

<br/>


| Version | GPU CUDA 12.1 Docker |
| --- | ----------- |
| 2.5 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.1` |
| 2.4 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.1` |
| 2.3 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_cuda_12.1` |
| 2.2 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_cuda_12.1` |
4 changes: 2 additions & 2 deletions WORKSPACE
@@ -50,7 +50,7 @@ new_local_repository(
# curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
# and update the sha256 with the result.

-xla_hash = 'be7eef5742089e328152908b8662e83e34bf73c1'
+xla_hash = '32ebd694c4d0442e241d76324ff1a721831366b4'

http_archive(
    name = "xla",
@@ -139,4 +139,4 @@ xla_workspace0()
load("@tsl//third_party/gpus:cuda_configure.bzl", "cuda_configure")
cuda_configure(name = "local_config_cuda")
load("@tsl//third_party/nccl:nccl_configure.bzl", "nccl_configure")
-nccl_configure(name = "local_config_nccl")
\ No newline at end of file
+nccl_configure(name = "local_config_nccl")
Binary file added docs/_static/img/dist_op_stack.png
101 changes: 101 additions & 0 deletions docs/distop.md
@@ -0,0 +1,101 @@
# Support of Torch Distributed API in PyTorch/XLA
PyTorch/XLA version 2.5 adopts the `torch.distributed` API. Before version 2.5, PyTorch/XLA supported collective ops only through the custom `torch_xla.core.xla_model.*` APIs. The `torch.distributed.*` APIs work whether or not you use `torch._dynamo`.

## Collective ops lowering
### Collective ops lowering stack
PyTorch/XLA version 2.5 introduces the [traceable collective communication APIs](https://github.com/pytorch/pytorch/issues/93173), which enable Dynamo to support collective ops by reimplementing op lowering. Collective ops are traceable through methods defined in the `torch.ops._c10d_functional` namespace. The following figure shows how an `all_reduce` collective op is lowered between `torch` and `torch_xla`:


<img src="_static/img/dist_op_stack.png" alt="Alt Text" width="500" height="400">

_<span style="text-decoration:underline;">Figure 1. Collective ops lowering stack</span>_

### Non-dynamo collective op lowering
Collective ops are lowered by registering the `ProcessGroupXla` backend:

```Python
# torch_xla/distributed/xla_backend.py
def _create_xla_process_group(prefix_store, rank, size, timeout):
  assert not xr.is_spmd(
  ), "XLA backend is not supported with SPMD. Please use a CPU process group instead."
  return ProcessGroupXla(prefix_store, rank, size, timeout)


def _register_xla_backend():
  dist.Backend.register_backend('xla', _create_xla_process_group, devices='xla')


class ProcessGroupXla(ProcessGroup):
  ...

  def allreduce(self, tensors, all_reduce_options):
    ...

  def allgather(self, output_tensors_list, input_tensors, opts=None):
    ...
```

The `ProcessGroupXla` backend is initialized in the multiprocess function call:

```Python
def _mp_fn(rank):
  dist.init_process_group("xla", init_method='xla://')
```

After `dist.init_process_group`, collective ops are dispatched through the registered process group instance:

```Python
# E.g., pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
@_exception_logger
def all_gather(tensor_list, tensor, group=None, async_op=False):
  ...
  group = group or _get_default_group()
  work = group.allgather([tensor_list], [tensor])  # uses ProcessGroupXla.allgather instead
```

### Dynamo collective op lowering
When you use Dynamo, certain collective ops are remapped to new functions in [pytorch/torch/distributed/_functional_collectives.py](https://github.com/pytorch/pytorch/blob/v2.5.0-rc10/torch/distributed/_functional_collectives.py#L1129-L1150). For example, `all_reduce()` is mapped to `all_reduce_inplace()`, and eventually to `torch.ops._c10d_functional.all_reduce()`. Once the op reaches a `_c10d_functional` function, PyTorch/XLA can rewrite it through its own lowering:

```C++
at::Tensor all_reduce(const at::Tensor& self, std::string reduceOp,
                      std::string /*group_name*/) {...}

TORCH_LIBRARY_IMPL(_c10d_functional, XLA, m) {
  m.impl("all_reduce", all_reduce);
}
```
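
As an illustration of this path, here is a minimal sketch (assuming a TPU runtime and the `openxla` Dynamo backend, mirroring the patterns in `test/pjrt/test_collective_ops_tpu.py`; the function and tensor names are illustrative):

```Python
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm

def _mp_fn(rank):
  dist.init_process_group("xla", init_method='xla://')
  device = xm.xla_device()

  def reduce_fn(t):
    # Under Dynamo, this call is remapped to
    # torch.ops._c10d_functional.all_reduce and lowered by PyTorch/XLA.
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t

  compiled = torch.compile(reduce_fn, backend='openxla')
  t = torch.ones(4, device=device)
  result = compiled(t)
  torch_xla.sync()

if __name__ == '__main__':
  torch_xla.launch(_mp_fn)
```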


## API description

PyTorch/XLA 2.5 supports four collective operations for both Dynamo and non-Dynamo cases. Our goal is to align the distributed operation (dist op) APIs with PyTorch's upstream implementation. Note that distributed collective ops do not work with GSPMD, where collective ops are injected automatically at the XLA compiler level. While the distributed function signatures remain consistent with upstream PyTorch, certain input restrictions still apply; for instance, specifying multiple process groups for distributed collective operations is not yet supported. For usage examples, refer to [test_collective_ops_tpu.py](https://github.com/pytorch/xla/blob/v2.5.0-rc10/test/pjrt/test_collective_ops_tpu.py), which demonstrates the use of collective ops in both Dynamo and non-Dynamo scenarios.
To use the distributed ops, call `dist.init_process_group` in your multiprocess function:

```Python
import torch.distributed as dist
import torch_xla

def _mp_fn(rank):
  dist.init_process_group("xla", init_method='xla://')
  ...

if __name__ == '__main__':
  torch_xla.launch(_mp_fn)
```
Below are the details for collective operation functions:
```Python
dist.all_reduce(input: torch.Tensor, op: dist.ReduceOp = ReduceOp.SUM)
```
`all_reduce` performs an in-place reduction on the `input` tensor by aggregating data from all nodes.
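
As a minimal runnable sketch (assuming one process per device on a TPU VM; names are illustrative), each process contributes its ordinal and ends up with the global sum:

```Python
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

def _mp_fn(rank):
  dist.init_process_group("xla", init_method='xla://')
  device = xm.xla_device()
  # Each process contributes its global ordinal: 0, 1, ..., world_size - 1.
  t = torch.tensor([float(xr.global_ordinal())], device=device)
  dist.all_reduce(t, op=dist.ReduceOp.SUM)
  torch_xla.sync()
  print(t.cpu())  # every process prints sum(range(world_size))

if __name__ == '__main__':
  torch_xla.launch(_mp_fn)
```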

```Python
dist.all_gather_into_tensor(output, input)
```
`all_gather_into_tensor` gathers the input tensor from all nodes and updates the `output` tensor in place. It also returns an alias of the output.
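
Continuing inside the `_mp_fn` sketch above (shapes are illustrative), the output tensor must hold `world_size` times as many elements as the input:

```Python
world_size = dist.get_world_size()
input = torch.tensor([float(xr.global_ordinal())], device=device)
output = torch.zeros(world_size, device=device)
dist.all_gather_into_tensor(output, input)
# output[i] == i for every rank i after the gather.
```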

```Python
dist.reduce_scatter_tensor(output, input, op: dist.ReduceOp = ReduceOp.SUM)
```
`reduce_scatter_tensor` reduces the input tensor across all nodes and scatters the result to the `output` tensor in place. It returns an alias of the output.
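
Again inside the same sketch, the input holds one chunk per process and each process receives the reduction of its own chunk (shapes are illustrative):

```Python
world_size = dist.get_world_size()
input = torch.ones(world_size, 2, device=device)  # world_size chunks of shape [2]
output = torch.zeros(2, device=device)
dist.reduce_scatter_tensor(output, input, op=dist.ReduceOp.SUM)
# output == [world_size, world_size]: each element is summed across all processes.
```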

```Python
dist.all_to_all_single(output, input, output_split_sizes=None, input_split_sizes=None)
```
The `all_to_all_single` function performs an all-to-all communication, updating the `output` tensor in place and returning its alias.

Note: Although `output_split_sizes` and `input_split_sizes` are accepted as arguments, they must be either None or set to all 1s. This limitation reflects a compromise between maintaining PyTorch’s API signature and the constraints of the XLA AllToAll operation.
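
A sketch of the call, following the pattern used in `test_collective_ops_tpu.py` (inside the same `_mp_fn` as above): rank `r` sends its value to every process and receives one value from each.

```Python
world_size = dist.get_world_size()
tensor_in = torch.tensor([float(xr.global_ordinal())] * world_size, device=device)
tensor_out = torch.zeros_like(tensor_in)
dist.all_to_all_single(tensor_out, tensor_in)
# tensor_out holds one value from each rank; as noted above, the XLA AllToAll
# op does not guarantee the rank order of the results.
```
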
26 changes: 22 additions & 4 deletions docs/spmd_advanced.md
@@ -8,10 +8,28 @@ PyTorch/XLA SPMD takes a single-device program, shards and executes it in parallel
```python
# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
    train_loader,  # wraps PyTorch DataLoader
    device,
    # assume 4d input and we want to shard at the batch dimension.
    input_sharding=xs.ShardingSpec(input_mesh, ('data', None, None, None)))
```

It is also possible to specify a different `input_sharding` for each element of the batch if the elements have different shapes:

```python
# if batch = next(train_loader) looks like
# {'x': <tensor of shape [s1, s2, s3, s4]>, 'y': <tensor of shape [s1, s2]>}

# MpDeviceLoader returns ParallelLoader.per_device_loader as iterator
train_loader = pl.MpDeviceLoader(
    train_loader,  # wraps PyTorch DataLoader
    device,
    # specify different sharding for each input of the batch.
    input_sharding={
        'x': xs.ShardingSpec(input_mesh, ('data', None, None, None)),
        'y': xs.ShardingSpec(input_mesh, ('data', None))
    }
)
```
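
These snippets assume an `input_mesh` has already been created. A minimal sketch of one possible mesh (a 1-D mesh that maps every device to the `data` axis; assuming `torch_xla.distributed.spmd` is imported as `xs`):

```python
import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

num_devices = xr.global_runtime_device_count()
# A single 'data' axis spanning all devices, matching the sharding specs above.
input_mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))
```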

### Virtual Device Optimization
4 changes: 2 additions & 2 deletions infra/ansible/config/cuda_deps.yaml
@@ -3,15 +3,15 @@
cuda_deps:
  # List all libcudnn8 versions with `apt list -a libcudnn8`
  libcudnn:
-    "12.4": libcudnn-cuda-12=9.1.1.17-1
+    "12.4": libcudnn9-cuda-12=9.1.1.17-1
    "12.3": libcudnn9-cuda-12=9.0.0.312-1
    "12.1": libcudnn8=8.9.2.26-1+cuda12.1
    "12.0": libcudnn8=8.8.0.121-1+cuda12.0
    "11.8": libcudnn8=8.7.0.84-1+cuda11.8
    "11.7": libcudnn8=8.5.0.96-1+cuda11.7
    "11.2": libcudnn8=8.1.1.33-1+cuda11.2
  libcudnn-dev:
-    "12.4": libcudnn-dev-cuda-12=9.1.1.17-1
+    "12.4": libcudnn9-dev-cuda-12=9.1.1.17-1
    "12.3": libcudnn9-dev-cuda-12=9.0.0.312-1
    "12.1": libcudnn8-dev=8.9.2.26-1+cuda12.1
    "12.0": libcudnn8-dev=8.8.0.121-1+cuda12.0
2 changes: 1 addition & 1 deletion infra/ansible/roles/build_srcs/tasks/main.yaml
@@ -44,7 +44,7 @@

- name: Build PyTorch/XLA
  ansible.builtin.command:
-    cmd: python setup.py bdist_wheel
+    cmd: python setup.py bdist_wheel -p manylinux_2_28_x86_64
    chdir: "{{ (src_root, 'pytorch/xla') | path_join }}"
  environment: "{{ env_vars }}"

4 changes: 2 additions & 2 deletions setup.py
@@ -64,10 +64,10 @@

base_dir = os.path.dirname(os.path.abspath(__file__))

-_date = '20240801'
+_date = '20240916'
_libtpu_version = f'0.1.dev{_date}'
_libtpu_storage_path = f'https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}+nightly-py3-none-any.whl'
-_jax_version = f'0.4.32.dev{_date}'
+_jax_version = f'0.4.33'


def _get_build_mode():
39 changes: 38 additions & 1 deletion test/pjrt/test_collective_ops_tpu.py
@@ -139,7 +139,7 @@ def test_all_to_all(self, pin_layout):
list(range(world_size))]])


-@absltest.skipIf(lambda: tpu.num_logical_cores_per_chip() >= 2,
+@absltest.skipIf(tpu.num_logical_cores_per_chip() >= 2,
"Dynamo not supported on TPU v2/v3")
class TestDistCollectiveOpsTpu(parameterized.TestCase):
"""Test for collective ops from torch.distributed"""
@@ -246,6 +246,32 @@ def callable(output, input):
    assert 'xla::reduce_scatter_tensor' in met.counter_names()
    return output.cpu()

  @staticmethod
  def _all_to_all_single(use_dynamo: bool):
    met.clear_all()
    dist.init_process_group("xla", init_method='xla://')
    device = xm.xla_device()

    def callable(output, input):
      dist.all_to_all_single(output, input)
      return output

    # check https://github.com/pytorch/pytorch/blob/758d78790164bfb041555daed380de96e06f78a3/torch/distributed/distributed_c10d.py#L3880
    # for input and output tensor example
    tensor_in = torch.tensor(
        [xr.local_ordinal()] * tpu.num_expected_global_devices(),
        dtype=torch.float,
        device=device)
    tensor_out = torch.zeros_like(tensor_in)
    f = torch.compile(callable, backend='openxla') if use_dynamo else callable
    output = f(tensor_out, tensor_in)
    torch_xla.sync()
    if not use_dynamo:
      assert 'xla::AllToAll' in met.counter_names()
    else:
      assert 'xla::all_to_all_single' in met.counter_names()
    return output.cpu()

  @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
  def test_all_reduce(self, use_dynamo):
    results = pjrt.run_multiprocess(self._all_reduce, use_dynamo=use_dynamo)
@@ -287,6 +313,17 @@ def test_reduce_scatter(self, use_dynamo):
    for index, val in results.items():
      torch.testing.assert_close(val, expected[index])

  @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
  def test_all_to_all_single(self, use_dynamo):
    results = pjrt.run_multiprocess(
        self._all_to_all_single, use_dynamo=use_dynamo)
    expected = torch.arange(
        tpu.num_expected_global_devices(), dtype=torch.float)
    # Note: the XLA AllToAll op does not honor the rank order of the
    # all_to_all, so results may arrive in a different order.
    for _, val in results.items():
      self.assertTrue(torch.allclose(val.sort().values, expected.sort().values))


if __name__ == '__main__':
  absltest.main()
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -234,6 +234,7 @@ function run_xla_op_tests3 {
run_test "$CDIR/stablehlo/test_unbounded_dynamism.py"
run_test "$CDIR/quantized_ops/test_quantized_matmul.py"
run_test "$CDIR/quantized_ops/test_dot_general.py"
run_test "$CDIR/spmd/test_mp_input_sharding.py"
run_test "$CDIR/spmd/test_xla_sharding.py"
run_test "$CDIR/spmd/test_xla_sharding_hlo.py"
run_test "$CDIR/spmd/test_xla_virtual_device.py"