Skip to content

fix(vqbet): use in-place fill_ to avoid overwriting DDP GPU buffers with CPU tensors#3128

Open
Altman-conquer wants to merge 3 commits intohuggingface:mainfrom
Altman-conquer:fix/vqbet-ddp-cpu-buffer-overwrite
Open

fix(vqbet): use in-place fill_ to avoid overwriting DDP GPU buffers with CPU tensors#3128
Altman-conquer wants to merge 3 commits intohuggingface:mainfrom
Altman-conquer:fix/vqbet-ddp-cpu-buffer-overwrite

Conversation

@Altman-conquer
Copy link

Summary

Fixes the VQBeT DDP training crash: RuntimeError: No backend type associated with device type cpu when the VQ discretization phase starts (around step 1000).

Root cause

In modeling_vqbet.py, when n_vqvae_training_steps is reached, the code sets:

self.vqvae_model.discretized = torch.tensor(True)
self.vqvae_model.vq_layer.freeze_codebook = torch.tensor(True)

Both discretized and freeze_codebook are registered buffers (via register_buffer()), so they live on the model device (e.g. GPU). Assigning torch.tensor(True) replaces these buffers with new tensors created on CPU by default. On the next forward pass, DistributedDataParallel._sync_buffers() tries to broadcast them with NCCL, which only supports GPU tensors, hence the error.

Fix

Update the buffers in-place instead of replacing them, so device and DDP registration are preserved:

self.vqvae_model.discretized.fill_(True)
self.vqvae_model.vq_layer.freeze_codebook.fill_(True)

Testing

  • Logic unchanged: we only set the same boolean flags in-place.
  • No new dependencies; behavior is equivalent for single-GPU and multi-GPU once buffers stay on the correct device.

Fixes #3127

…ith CPU tensors

When VQ discretization phase completes, the code was overwriting
register_buffer('discretized') and register_buffer('freeze_codebook')
with torch.tensor(True), which is created on CPU. DDP then fails in
_sync_buffers() with: RuntimeError: No backend type associated with
device type cpu. Fix by updating the buffers in-place with .fill_(True)
so device and registration are preserved.

Made-with: Cursor
Copilot AI review requested due to automatic review settings March 11, 2026 06:50
@github-actions github-actions bot added the policies Items related to robot policies label Mar 11, 2026
Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Fixes a DDP crash in the VQBeT discretization phase by ensuring registered boolean buffers remain on the model device (e.g., CUDA) instead of being replaced with newly-created CPU tensors.

Changes:

  • Update discretized and freeze_codebook registered buffers using in-place fill_(True) to preserve device placement and DDP buffer registration.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

…scretization

Verifies that discretize() updates the 'discretized' and 'freeze_codebook'
registered buffers in-place (via fill_()) rather than replacing them with new
CPU tensors. The test checks data_ptr() identity and that the tensors remain
registered buffers after the call. This prevents regressions of the DDP fix.

Made-with: Cursor
@github-actions github-actions bot added the tests Problems with test coverage, failures, or improvements to testing label Mar 11, 2026
…fter discretize()

Directly catches the original DDP failure mode: when buffers are replaced with
torch.tensor(True) they land on CPU, causing NCCL to raise 'No backend type
associated with device type cpu' in _sync_buffers(). The GPU test places the
model on cuda:0 and asserts both buffers remain on CUDA after discretization.

Made-with: Cursor
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

policies Items related to robot policies tests Problems with test coverage, failures, or improvements to testing

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] VQBeT DDP training crashes with "No backend type associated with device type cpu" when discretization starts

2 participants