Description
🐛 Bug
I am trying to run the example code for distributed training, but all my attempts either hang or return an error.
To Reproduce
I have tried test_train_mp_mnist and the example from the PyTorch/XLA docs (https://pytorch.org/xla/master/learn/pjrt.html#tl-dr).
Right now I am focusing on the latter because it is simpler:
import os
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_backend
import torch_xla.runtime as xr


def _mp_fn(index):
    device = xm.xla_device()
    dist.init_process_group('xla', init_method='xla://')

    torch.manual_seed(42)
    model = nn.Linear(128, 10).to(device)

    # Optional for TPU v4 and GPU
    xm.broadcast_master_param(model)
    model = DDP(model)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=.001)

    for i in range(10):
        data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        xm.mark_step()

    # Print mean parameters so we can confirm they're the same across replicas
    print([p.mean() for p in model.parameters()])


if __name__ == '__main__':
    torch_xla.launch(_mp_fn)
Steps to reproduce the behavior:
For this example I am using a single machine with 2 GPUs.
When I try to run with 1 GPU, just to check that the command is correct, I get the following error:
(venv) ~$ PJRT_DEVICE=CUDA GPU_NUM_DEVICES=1 torchrun --nnodes=1 --nproc-per-node=1 example-xla.py --epochs 1
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1736507062.028835 2365508 service.cc:148] XLA service 0x5570d36760c0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1736507062.028872 2365508 service.cc:156] StreamExecutor device (0): NVIDIA A40, Compute Capability 8.6
I0000 00:00:1736507062.028877 2365508 service.cc:156] StreamExecutor device (1): NVIDIA A40, Compute Capability 8.6
I0000 00:00:1736507062.030735 2365508 se_gpu_pjrt_client.cc:943] Using BFC allocator.
I0000 00:00:1736507062.030773 2365508 gpu_helpers.cc:114] XLA backend allocating 35802464256 bytes on device 0 for BFCAllocator.
I0000 00:00:1736507062.030796 2365508 gpu_helpers.cc:114] XLA backend allocating 35802464256 bytes on device 1 for BFCAllocator.
I0000 00:00:1736507062.030817 2365508 gpu_helpers.cc:154] XLA backend will use up to 11934154752 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1736507062.030828 2365508 gpu_helpers.cc:154] XLA backend will use up to 11934154752 bytes on device 1 for CollectiveBFCAllocator.
2025-01-10 12:04:22.505221: E external/xla/xla/status_macros.cc:56] INTERNAL: RET_CHECK failure (external/xla/xla/service/hlo_verifier.cc:477) subgroup_size == 1 || shard_count == subgroup_size shard_count = 1, subgroup_size = 2, %all-gather.6 = s64[1]{0} all-gather(s64[1]{0} %add.5), replica_groups={}, dimensions={0}
*** Begin stack trace ***
tsl::CurrentStackTrace()
xla::status_macros::MakeErrorStream::Impl::GetStatus()
xla::ShapeVerifier::HandleAllGather(xla::HloInstruction*)
absl::lts_20230802::Status xla::HloInstruction::Visit<xla::HloInstruction*>(xla::DfsHloVisitorBase<xla::HloInstruction*>*)
absl::lts_20230802::Status xla::HloInstruction::Accept<xla::HloInstruction*>(xla::DfsHloVisitorBase<xla::HloInstruction*>*, bool, bool, bool)
absl::lts_20230802::Status xla::HloComputation::Accept<xla::HloInstruction*>(xla::DfsHloVisitorBase<xla::HloInstruction*>*) const
xla::HloVerifier::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char> >, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char> > > > const&)
xla::CreateModuleFromProto(xla::HloModuleProto const&, xla::HloModuleConfig const&, bool)
xla::Service::BuildExecutable(xla::HloModuleProto const&, std::unique_ptr<xla::HloModuleConfig, std::default_delete<xla::HloModuleConfig> >, xla::Backend*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, bool)
xla::LocalService::CompileExecutables(xla::XlaComputation const&, absl::lts_20230802::Span<xla::Shape const* const>, xla::ExecutableBuildOptions const&)
xla::LocalClient::Compile(xla::XlaComputation const&, absl::lts_20230802::Span<xla::Shape const* const>, xla::ExecutableBuildOptions const&)
xla::PjRtStreamExecutorClient::CompileInternal(xla::XlaComputation const&, std::vector<xla::Shape const*, std::allocator<xla::Shape const*> > const&, std::function<absl::lts_20230802::StatusOr<std::pair<std::vector<xla::Shape, std::allocator<xla::Shape> >, xla::Shape> > (xla::HloModule const&)>, xla::CompileOptions)
xla::PjRtStreamExecutorClient::Compile(xla::XlaComputation const&, xla::CompileOptions)
xla::StreamExecutorGpuClient::Compile(xla::XlaComputation const&, xla::CompileOptions)
torch_xla::runtime::PjRtComputationClient::Compile(std::vector<torch_xla::runtime::ComputationClient::CompileInstance, std::allocator<torch_xla::runtime::ComputationClient::CompileInstance> >)
torch_xla::XLAGraphExecutor::Compile(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >&, absl::lts_20230802::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorCollection const&, torch::lazy::LazyGraphExecutor::PostOrderData*, std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&)
torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
torch_xla::XLAGraphExecutor::GetTensors(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*)
torch_xla::bridge::XlaCreateTensorList(c10::IListRef<at::Tensor> const&)
torch_xla::XLANativeFunctions::_to_cpu(c10::ArrayRef<at::Tensor>)
at::_ops::_to_cpu::call(c10::ArrayRef<at::Tensor>)
at::native::cpu_fallback(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*, bool, c10::DispatchKey)
torch_xla::xla_fallback(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*)
at::native::_call_fallback_fn<&torch_xla::xla_fallback, at::_ops::_local_scalar_dense, false, c10::Scalar (at::Tensor const&)>::call(at::Tensor const&)
torch_xla::XLANativeFunctions::_local_scalar_dense(at::Tensor const&)
c10::Dispatcher::callBoxed(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const
at::_ops::_local_scalar_dense::redispatch(c10::DispatchKeySet, at::Tensor const&)
at::_ops::_local_scalar_dense::call(at::Tensor const&)
at::native::item(at::Tensor const&)
at::_ops::item::call(at::Tensor const&)
int at::Tensor::item<int>() const
c10d::verify_params_across_processes(c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::optional<std::weak_ptr<c10d::Logger> > const&)
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyObject_FastCallDictTstate
_PyObject_MakeTpCall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
_PyFunction_Vectorcall
_PyEval_EvalFrameDefault
PyEval_EvalCode
_PyRun_SimpleFileObject
_PyRun_AnyFileObject
Py_RunMain
Py_BytesMain
__libc_start_main
_start
*** End stack trace ***
[rank0]: Traceback (most recent call last):
[rank0]: File "example-xla.py", line 44, in <module>
[rank0]: torch_xla.launch(_mp_fn)
[rank0]: File "xla/venv/lib/python3.10/site-packages/torch_xla/torch_xla.py", line 231, in launch
[rank0]: fn(xu.getenv_as(xenv.LOCAL_RANK, int), *args)
[rank0]: File "xla/unet/ddp/example-xla.py", line 24, in _mp_fn
[rank0]: model = DDP(model)
[rank0]: File "xla/venv/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 825, in __init__
[rank0]: _verify_param_shape_across_processes(self.process_group, parameters)
[rank0]: File "xla/venv/lib/python3.10/site-packages/torch/distributed/utils.py", line 288, in _verify_param_shape_across_processes
[rank0]: return dist._verify_params_across_processes(process_group, tensors, logger)
[rank0]: RuntimeError: Bad StatusOr access: INTERNAL: during context [Unknown]: RET_CHECK failure (external/xla/xla/service/hlo_verifier.cc:477) subgroup_size == 1 || shard_count == subgroup_size shard_count = 1, subgroup_size = 2, %all-gather.6 = s64[1]{0} all-gather(s64[1]{0} %add.5), replica_groups={}, dimensions={0}
E0110 12:04:23.262000 2365422 torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 0 (pid: 2365508) of binary: xla/venv/bin/python
Traceback (most recent call last):
File "xla/venv/bin/torchrun", line 8, in <module>
sys.exit(main())
File "xla/venv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
return f(*args, **kwargs)
File "xla/venv/lib/python3.10/site-packages/torch/distributed/run.py", line 919, in main
run(args)
File "xla/venv/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
elastic_launch(
File "xla/venv/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
return launch_agent(self._config, self._entrypoint, list(args))
File "xla/venv/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
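One thing that may be relevant: even with GPU_NUM_DEVICES=1, the log above shows the XLA service initializing both A40s, and the RET_CHECK complains about subgroup_size = 2 with shard_count = 1, so the runtime may still be seeing two devices. A quick way to check the device count the runtime reports could be something like this (a sketch; the file name is made up, and I am assuming xr.global_runtime_device_count() is available in this release):

# device_count_check.py (hypothetical name) -- run with PJRT_DEVICE=CUDA GPU_NUM_DEVICES=1
import torch_xla.runtime as xr

# How many devices does the PJRT runtime actually see?
print(xr.global_runtime_device_count())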
On the other hand, if I try to run with both GPUs to get actual parallelism, it just hangs. When I cancel the execution, the traceback shows it stuck inside torch_xla.launch. The command is:
torchrun --nnodes=1 --nproc-per-node=2 example-xla.py --epochs 1
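Since cancelling gives no useful traceback, I assume the hang location in each torchrun worker could be captured with a periodic stack dump from the standard-library faulthandler, added at the top of _mp_fn (a debugging sketch only, not part of the original example; the 60 s timeout is arbitrary):

# Debugging sketch: dump every thread's stack to stderr every 60 s so the
# point where each worker is stuck becomes visible.
import faulthandler
faulthandler.dump_traceback_later(60, repeat=True)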
Environment
$> pip freeze | grep torch
torch==2.5.1
torch-xla @ https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=86d0a9af00fb678f903e5c4968e30dca3c50d6ce64aa33da9314b2134418ace3
torchvision==0.20.1
CUDA 12.5
Driver 555.42.06
Additional context
I can successfully execute other, non-distributed XLA scripts.
When I try to use multiple nodes with torchrun, it also hangs, while the same command with non-XLA scripts works perfectly.
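For reference, the kind of minimal single-device script that does run correctly on this machine (launched with PJRT_DEVICE=CUDA python script.py) looks roughly like this (a sketch, not the exact script I used):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()  # picks the CUDA device via PJRT
x = torch.randn(128, 128, device=device)
w = torch.randn(128, 10, device=device)
y = x @ w
xm.mark_step()            # materialize the lazy graph
print(y.shape, y.device)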