🐛 Bug
torch_xla's xm.all_gather works when the tensors on all devices have the same size along dim=0, but it hangs when the sizes differ along dim=0.
To Reproduce
Save this code in test_all_gather.py
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.xla_backend as xb  # registers the "xla" process group backend
import torch.distributed

def test_all_gather():
    same = [512, 512, 512, 512, 512, 512, 512, 512]
    different = [416, 536, 560, 544, 576, 512, 592, 360]
    torch.distributed.init_process_group(backend="xla", init_method="xla://")
    rank = torch.distributed.get_rank()
    device = xm.xla_device()
    # Equal sizes along dim=0 on every rank: this gather completes.
    input = torch.randn((same[rank], 16), dtype=torch.float32, device=device)
    all_inputs = xm.all_gather(input, dim=0, groups=[[0, 1, 2, 3, 4, 5, 6, 7]], pin_layout=False)
    print(f"!!!!!! rank: {rank}, all_inputs: {all_inputs}")
    # Different sizes along dim=0 per rank: this gather hangs.
    input = torch.randn((different[rank], 16), dtype=torch.float32, device=device)
    all_inputs = xm.all_gather(input, dim=0, groups=[[0, 1, 2, 3, 4, 5, 6, 7]], pin_layout=False)
    print(f"!!!!!! rank: {rank}, all_inputs: {all_inputs}")
    torch.distributed.destroy_process_group()

if __name__ == "__main__":
    test_all_gather()
Run it with:
torchrun --nproc_per_node=8 test_all_gather.py
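For comparison, the same uneven gather can also be attempted through torch.distributed.all_gather, the API whose documentation is cited under Additional context below. The following is only an illustrative sketch (the helper name gather_uneven is made up here); it assumes, as that documentation is read to say, that per-rank output tensors of different sizes are accepted:

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the "xla" backend

def gather_uneven(sizes, rank, device):
    # Illustrative helper, not part of the original repro: each rank contributes
    # sizes[rank] rows, and one pre-sized output tensor is allocated per rank,
    # as the list-based all_gather API expects.
    inp = torch.randn((sizes[rank], 16), dtype=torch.float32, device=device)
    outputs = [torch.empty((n, 16), dtype=torch.float32, device=device) for n in sizes]
    dist.all_gather(outputs, inp)
    return torch.cat(outputs, dim=0)

Called with the different list from the repro (after init_process_group), this should return a (4096, 16) tensor on every rank if uneven sizes are handled.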
Expected behavior
It should gather the tensors from all devices and concatenate them along dim=0. With the uneven sizes listed in the repro, every rank should therefore receive a tensor of shape (4096, 16), since the per-rank sizes sum to 4096.
Environment
Docker image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.4
Additional context
According to the torch.distributed documentation for all_gather (https://pytorch.org/docs/stable/distributed.html), uneven tensor sizes are supported.
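A possible workaround, not verified as part of this report: pad every rank's tensor along dim=0 to the largest per-rank length so that the gather sees equal shapes, then slice the padding back out. A minimal sketch, assuming the per-rank lengths are known on every rank (as they are in the repro) and using an illustrative helper name pad_gather:

import torch
import torch_xla.core.xla_model as xm

def pad_gather(t, lengths, dim=0):
    # Illustrative helper: lengths holds each rank's true size along dim.
    max_len = max(lengths)
    pad_rows = max_len - t.shape[dim]
    if pad_rows > 0:
        pad_shape = list(t.shape)
        pad_shape[dim] = pad_rows
        t = torch.cat([t, t.new_zeros(pad_shape)], dim=dim)
    # Every rank now contributes a tensor of identical shape, so the gather completes.
    gathered = xm.all_gather(t, dim=dim, pin_layout=False)
    # Split into one max_len-sized chunk per rank and drop each chunk's padding.
    chunks = gathered.split(max_len, dim=dim)
    return torch.cat([c.narrow(dim, 0, n) for c, n in zip(chunks, lengths)], dim=dim)

In the repro above this would be pad_gather(input, different) on each rank, which should give the same (4096, 16) result that a direct uneven gather would.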