fix(vqbet): use in-place fill_ to avoid overwriting DDP GPU buffers with CPU tensors#3128
Open
Altman-conquer wants to merge 3 commits intohuggingface:mainfrom
Open
Conversation
…ith CPU tensors
When VQ discretization phase completes, the code was overwriting
register_buffer('discretized') and register_buffer('freeze_codebook')
with torch.tensor(True), which is created on CPU. DDP then fails in
_sync_buffers() with: RuntimeError: No backend type associated with
device type cpu. Fix by updating the buffers in-place with .fill_(True)
so device and registration are preserved.
Made-with: Cursor
Contributor
There was a problem hiding this comment.
Pull request overview
Fixes a DDP crash in the VQBeT discretization phase by ensuring registered boolean buffers remain on the model device (e.g., CUDA) instead of being replaced with newly-created CPU tensors.
Changes:
- Update
discretizedandfreeze_codebookregistered buffers using in-placefill_(True)to preserve device placement and DDP buffer registration.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
…scretization Verifies that discretize() updates the 'discretized' and 'freeze_codebook' registered buffers in-place (via fill_()) rather than replacing them with new CPU tensors. The test checks data_ptr() identity and that the tensors remain registered buffers after the call. This prevents regressions of the DDP fix. Made-with: Cursor
…fter discretize() Directly catches the original DDP failure mode: when buffers are replaced with torch.tensor(True) they land on CPU, causing NCCL to raise 'No backend type associated with device type cpu' in _sync_buffers(). The GPU test places the model on cuda:0 and asserts both buffers remain on CUDA after discretization. Made-with: Cursor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fixes the VQBeT DDP training crash:
RuntimeError: No backend type associated with device type cpuwhen the VQ discretization phase starts (around step 1000).Root cause
In
modeling_vqbet.py, whenn_vqvae_training_stepsis reached, the code sets:Both
discretizedandfreeze_codebookare registered buffers (viaregister_buffer()), so they live on the model device (e.g. GPU). Assigningtorch.tensor(True)replaces these buffers with new tensors created on CPU by default. On the next forward pass,DistributedDataParallel._sync_buffers()tries to broadcast them with NCCL, which only supports GPU tensors, hence the error.Fix
Update the buffers in-place instead of replacing them, so device and DDP registration are preserved:
Testing
Fixes #3127