Multi-GPU training hangs using single and multiple nodes #8549

Open
@Patataman

Description

🐛 Bug

I am trying to run the example code for distributed training, but all my attempts either hang or return an error.

To Reproduce

I have tried test_train_mp_mnist and the example from the PyTorch/XLA docs: https://pytorch.org/xla/master/learn/pjrt.html#tl-dr

Right now I am focusing on the latter because it is simpler:

import os

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_backend
import torch_xla.runtime as xr


def _mp_fn(index):
  device = xm.xla_device()
  dist.init_process_group('xla', init_method='xla://')

  torch.manual_seed(42)
  model = nn.Linear(128, 10).to(device)

  # Optional for TPU v4 and GPU
  xm.broadcast_master_param(model)
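  # Wrap the model with DDP on top of the 'xla' process group initialized above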
  model = DDP(model)

  loss_fn = nn.MSELoss()
  optimizer = optim.SGD(model.parameters(), lr=.001)

  for i in range(10):
    data, target = torch.randn((128, 128), device=device), torch.randn((128, 10), device=device)

    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()

    optimizer.step()
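    # mark_step ends the lazy trace for this step and triggers XLA compilation/execution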
    xm.mark_step()

  # Print mean parameters so we can confirm they're the same across replicas
  print([p.mean() for p in model.parameters()])

if __name__ == '__main__':
  torch_xla.launch(_mp_fn)

Steps to reproduce the behavior:

For this example I am using a single machine with 2 GPUs.

When I run with 1 GPU to check that the command is correct, it gives me the following error:

  • PJRT_DEVICE=CUDA GPU_NUM_DEVICES=1 torchrun --nnodes=1 --nproc-per-node=1 example-xla.py --epochs 1
(venv) ~$ PJRT_DEVICE=CUDA GPU_NUM_DEVICES=1 torchrun --nnodes=1 --nproc-per-node=1 example-xla.py --epochs 1
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1736507062.028835 2365508 service.cc:148] XLA service 0x5570d36760c0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1736507062.028872 2365508 service.cc:156]   StreamExecutor device (0): NVIDIA A40, Compute Capability 8.6
I0000 00:00:1736507062.028877 2365508 service.cc:156]   StreamExecutor device (1): NVIDIA A40, Compute Capability 8.6
I0000 00:00:1736507062.030735 2365508 se_gpu_pjrt_client.cc:943] Using BFC allocator.
I0000 00:00:1736507062.030773 2365508 gpu_helpers.cc:114] XLA backend allocating 35802464256 bytes on device 0 for BFCAllocator.
I0000 00:00:1736507062.030796 2365508 gpu_helpers.cc:114] XLA backend allocating 35802464256 bytes on device 1 for BFCAllocator.
I0000 00:00:1736507062.030817 2365508 gpu_helpers.cc:154] XLA backend will use up to 11934154752 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1736507062.030828 2365508 gpu_helpers.cc:154] XLA backend will use up to 11934154752 bytes on device 1 for CollectiveBFCAllocator.
2025-01-10 12:04:22.505221: E external/xla/xla/status_macros.cc:56] INTERNAL: RET_CHECK failure (external/xla/xla/service/hlo_verifier.cc:477) subgroup_size == 1 || shard_count == subgroup_size shard_count = 1, subgroup_size = 2, %all-gather.6 = s64[1]{0} all-gather(s64[1]{0} %add.5), replica_groups={}, dimensions={0}
*** Begin stack trace ***
        tsl::CurrentStackTrace()

        xla::status_macros::MakeErrorStream::Impl::GetStatus()

        xla::ShapeVerifier::HandleAllGather(xla::HloInstruction*)
        absl::lts_20230802::Status xla::HloInstruction::Visit<xla::HloInstruction*>(xla::DfsHloVisitorBase<xla::HloInstruction*>*)

        absl::lts_20230802::Status xla::HloInstruction::Accept<xla::HloInstruction*>(xla::DfsHloVisitorBase<xla::HloInstruction*>*, bool, bool, bool)
        absl::lts_20230802::Status xla::HloComputation::Accept<xla::HloInstruction*>(xla::DfsHloVisitorBase<xla::HloInstruction*>*) const
        xla::HloVerifier::Run(xla::HloModule*, absl::lts_20230802::flat_hash_set<std::basic_string_view<char, std::char_traits<char> >, absl::lts_20230802::container_internal::StringHash, absl::lts_20230802::container_internal::StringEq, std::allocator<std::basic_string_view<char, std::char_traits<char> > > > const&)
        xla::CreateModuleFromProto(xla::HloModuleProto const&, xla::HloModuleConfig const&, bool)
        xla::Service::BuildExecutable(xla::HloModuleProto const&, std::unique_ptr<xla::HloModuleConfig, std::default_delete<xla::HloModuleConfig> >, xla::Backend*, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, bool)
        xla::LocalService::CompileExecutables(xla::XlaComputation const&, absl::lts_20230802::Span<xla::Shape const* const>, xla::ExecutableBuildOptions const&)
        xla::LocalClient::Compile(xla::XlaComputation const&, absl::lts_20230802::Span<xla::Shape const* const>, xla::ExecutableBuildOptions const&)
        xla::PjRtStreamExecutorClient::CompileInternal(xla::XlaComputation const&, std::vector<xla::Shape const*, std::allocator<xla::Shape const*> > const&, std::function<absl::lts_20230802::StatusOr<std::pair<std::vector<xla::Shape, std::allocator<xla::Shape> >, xla::Shape> > (xla::HloModule const&)>, xla::CompileOptions)
        xla::PjRtStreamExecutorClient::Compile(xla::XlaComputation const&, xla::CompileOptions)
        xla::StreamExecutorGpuClient::Compile(xla::XlaComputation const&, xla::CompileOptions)
        torch_xla::runtime::PjRtComputationClient::Compile(std::vector<torch_xla::runtime::ComputationClient::CompileInstance, std::allocator<torch_xla::runtime::ComputationClient::CompileInstance> >)
        torch_xla::XLAGraphExecutor::Compile(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >&, absl::lts_20230802::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorCollection const&, torch::lazy::LazyGraphExecutor::PostOrderData*, std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&)
        torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
        torch_xla::XLAGraphExecutor::GetTensors(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*)
        torch_xla::bridge::XlaCreateTensorList(c10::IListRef<at::Tensor> const&)
        torch_xla::XLANativeFunctions::_to_cpu(c10::ArrayRef<at::Tensor>)

        at::_ops::_to_cpu::call(c10::ArrayRef<at::Tensor>)

        at::native::cpu_fallback(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*, bool, c10::DispatchKey)
        torch_xla::xla_fallback(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*)
        at::native::_call_fallback_fn<&torch_xla::xla_fallback, at::_ops::_local_scalar_dense, false, c10::Scalar (at::Tensor const&)>::call(at::Tensor const&)
        torch_xla::XLANativeFunctions::_local_scalar_dense(at::Tensor const&)

        c10::Dispatcher::callBoxed(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const


        at::_ops::_local_scalar_dense::redispatch(c10::DispatchKeySet, at::Tensor const&)


        at::_ops::_local_scalar_dense::call(at::Tensor const&)
        at::native::item(at::Tensor const&)

        at::_ops::item::call(at::Tensor const&)
        int at::Tensor::item<int>() const
        c10d::verify_params_across_processes(c10::intrusive_ptr<c10d::ProcessGroup, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroup> > const&, std::vector<at::Tensor, std::allocator<at::Tensor> > const&, std::optional<std::weak_ptr<c10d::Logger> > const&)



        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyObject_FastCallDictTstate

        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault

        PyEval_EvalCode



        _PyRun_SimpleFileObject
        _PyRun_AnyFileObject
        Py_RunMain
        Py_BytesMain

        __libc_start_main
        _start
*** End stack trace ***

[rank0]: Traceback (most recent call last):
[rank0]:   File "example-xla.py", line 44, in <module>
[rank0]:     torch_xla.launch(_mp_fn)
[rank0]:   File "xla/venv/lib/python3.10/site-packages/torch_xla/torch_xla.py", line 231, in launch
[rank0]:     fn(xu.getenv_as(xenv.LOCAL_RANK, int), *args)
[rank0]:   File "xla/unet/ddp/example-xla.py", line 24, in _mp_fn
[rank0]:     model = DDP(model)
[rank0]:   File "xla/venv/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 825, in __init__
[rank0]:     _verify_param_shape_across_processes(self.process_group, parameters)
[rank0]:   File "xla/venv/lib/python3.10/site-packages/torch/distributed/utils.py", line 288, in _verify_param_shape_across_processes
[rank0]:     return dist._verify_params_across_processes(process_group, tensors, logger)
[rank0]: RuntimeError: Bad StatusOr access: INTERNAL: during context [Unknown]: RET_CHECK failure (external/xla/xla/service/hlo_verifier.cc:477) subgroup_size == 1 || shard_count == subgroup_size shard_count = 1, subgroup_size = 2, %all-gather.6 = s64[1]{0} all-gather(s64[1]{0} %add.5), replica_groups={}, dimensions={0}
E0110 12:04:23.262000 2365422 torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 0 (pid: 2365508) of binary: xla/venv/bin/python
Traceback (most recent call last):
  File "xla/venv/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/run.py", line 919, in main
    run(args)
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/run.py", line 910, in run
    elastic_launch(
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "xla/venv/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 

On the other hand, if I run with both GPUs to get actual parallelism, it just hangs. When I cancel the execution, the traceback I get suggests it is stuck in torch_xla.launch.

  • torchrun --nnodes=1 --nproc-per-node=2 example-xla.py --epochs 1
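
If a stack dump of the hung workers is useful, it can be captured with something like py-spy (diagnostic sketch only; <PID> is a placeholder for one of the worker process ids):

  • py-spy dump --pid <PID>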

Environment

$> pip freeze | grep torch
torch==2.5.1
torch-xla @ https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=86d0a9af00fb678f903e5c4968e30dca3c50d6ce64aa33da9314b2134418ace3
torchvision==0.20.1

CUDA 12.5
Driver 555.42.06

Additional context

I can successfully execute other non-parallel XLA scripts.
When I try to use multiple nodes with torchrun, it also hangs, while the same command with non-XLA scripts works perfectly.
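
For reference, the multi-node runs I mean look roughly like this (node count, rank, and rendezvous endpoint are illustrative placeholders, not my exact values):

  • PJRT_DEVICE=CUDA torchrun --nnodes=2 --node-rank=0 --nproc-per-node=2 --rdzv-backend=c10d --rdzv-endpoint=<MASTER_IP>:<PORT> example-xla.py --epochs 1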
