Description
When training VQBeT with DistributedDataParallel (DDP) (e.g. multi-GPU), training crashes right after the VQ discretization phase starts (around step 1000), with:
RuntimeError: No backend type associated with device type cpu
The crash occurs during DistributedDataParallel._sync_buffers(), when NCCL tries to synchronize buffers across ranks.
Environment
- LeRobot: main (or v0.5.0)
- Training: lerobot-train --policy=vqbet ... with more than one GPU (DDP)
- Python: 3.10
- PyTorch: with CUDA and NCCL
Steps to reproduce
- Train VQBeT with 2+ GPUs so that DDP is used.
- Let training run until the VQ phase finishes (you see "Finished discretizing action data!" in the logs).
- On the next policy.forward(batch) call, the process crashes.
Observed traceback
[rank0]: File ".../lerobot/scripts/lerobot_train.py", line 403, in train
[rank0]: train_tracker, output_dict = update_policy(...)
[rank0]: File ".../lerobot/scripts/lerobot_train.py", line 115, in update_policy
[rank0]: loss, output_dict = policy.forward(batch)
[rank0]: File ".../torch/nn/parallel/distributed.py", line 1662, in forward
[rank0]: inputs, kwargs = self._pre_forward(*inputs, **kwargs)
[rank0]: File ".../torch/nn/parallel/distributed.py", line 1558, in _pre_forward
[rank0]: self._sync_buffers()
[rank0]: File ".../torch/nn/parallel/distributed.py", line 2195, in _sync_buffers
[rank0]: self._sync_module_buffers(authoritative_rank)
...
[rank0]: RuntimeError: No backend type associated with device type cpu
Root cause
In src/lerobot/policies/vqbet/modeling_vqbet.py, two buffers are registered with register_buffer() (so they live on the model device, e.g. GPU):
- self.vqvae_model.discretized (around line 775: register_buffer("discretized", torch.tensor(False)))
- self.vqvae_model.vq_layer.freeze_codebook (in vqbet_utils.py: register_buffer("freeze_codebook", torch.tensor(False)))
When the discretization phase finishes (around lines 470–471), the code overwrites these buffers with new tensors:
self.vqvae_model.discretized = torch.tensor(True)
self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True)
torch.tensor(True) is created on CPU by default, so the registered buffers are replaced by CPU tensors. On the next forward pass, DDP's _sync_buffers() tries to broadcast these buffers with NCCL, which only supports GPU tensors, hence: RuntimeError: No backend type associated with device type cpu.
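The buffer-replacement behavior can be reproduced in isolation with a toy module (a sketch; Toy is a hypothetical stand-in, not the actual VQBeT class):

```python
import torch
from torch import nn


class Toy(nn.Module):
    # Hypothetical stand-in for vqvae_model: one registered flag buffer.
    def __init__(self):
        super().__init__()
        self.register_buffer("discretized", torch.tensor(False))


m = Toy()
before = m.discretized

# Plain assignment swaps in a fresh tensor created on the CPU default
# device, discarding whatever device the module was moved to.
m.discretized = torch.tensor(True)
print(m.discretized is before)  # False: a new tensor object

# In-place update keeps the original tensor object (and its device).
m2 = Toy()
before2 = m2.discretized
m2.discretized.fill_(True)
print(m2.discretized is before2)  # True
print(bool(m2.discretized))       # True
```

On a CUDA machine the reassigned buffer ends up on cpu while the rest of the module stays on cuda, which is exactly the mismatch DDP trips over.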
Suggested fix
Do not replace the buffers; update them in-place so device and registration are preserved:
self.vqvae_model.discretized.fill_(True)
self.vqvae_model.vq_layer.freeze_codebook.fill_(True)
I can open a PR with this change if the maintainers are okay with this approach.
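If replacing the tensor is preferred for some reason, an equivalent workaround is to create the new value on the buffer's existing device (again a sketch with a hypothetical Toy module, not the actual VQBeT code):

```python
import torch
from torch import nn


class Toy(nn.Module):
    # Hypothetical module with a registered flag buffer.
    def __init__(self):
        super().__init__()
        self.register_buffer("discretized", torch.tensor(False))


m = Toy()
# Build the replacement on the buffer's current device rather than the
# CPU default, so the buffer stays wherever the module lives.
m.discretized = torch.tensor(True, device=m.discretized.device)
print(m.discretized.item())  # True
```

fill_ still seems preferable, since it also preserves the tensor object itself and can't drift out of sync with the module's device.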