Fix RuntimeError in Cross Attention Using Out-Of-Place Addition #85

Open · wants to merge 1 commit into main

Conversation

geranium12

Description

When training the BLT model, I encountered a runtime error caused by an in-place operation on tensor views.

The problem arises in local_models.py, where the following in-place addition is performed:

    patch_embeds += patch_embeds_cross
    return patch_embeds 
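
For illustration, here is a minimal standalone sketch of the same failure mode. It uses torch.split as a stand-in for the view-producing machinery in the traceback below (torch.split, like that code path, returns multiple views of one base tensor); it is not the actual BLT module code.

    import torch

    x = torch.randn(2, 3, requires_grad=True)
    # split returns multiple views of the same base tensor
    a, b = x.split(1, dim=0)

    try:
        a += 1.0  # in-place addition on a multi-output view
    except RuntimeError as e:
        # Raises the same class of error as in the traceback below:
        # "... is a view ... This view is the output of a function that
        # returns multiple views. Such functions do not allow the output
        # views to be modified inplace. ..."
        print(e)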

Error Traceback

[rank0]: Traceback (most recent call last):
[rank0]: File "<frozen runpy>", line 198, in _run_module_as_main
[rank0]: File "<frozen runpy>", line 88, in _run_code
[rank0]: File "/mnt/storage1/geranium/blt/bytelatent/train.py", line 812, in <module>
[rank0]: main()
[rank0]: File "/mnt/storage1/geranium/blt/bytelatent/train.py", line 808, in main
[rank0]: train(train_args)
[rank0]: File "/mnt/storage1/geranium/blt/bytelatent/train.py", line 486, in train
[rank0]: loss.backward()
[rank0]: File "/mnt/storage1/geranium/blt/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 581, in backward
[rank0]: torch.autograd.backward(
[rank0]: File "/mnt/storage1/geranium/blt/.venv/lib/python3.11/site-packages/torch/autograd/__init__.py", line 347, in backward
[rank0]: _engine_run_backward(
[rank0]: File "/mnt/storage1/geranium/blt/.venv/lib/python3.11/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
[rank0]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Output 0 of RegisterPostBackwardFunctionBackward is a view and its base or another view of its base has been modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.

Error Traceback with torch.autograd.set_detect_anomaly(True)

/mnt/storage1/geranium/blt/.venv/lib/python3.11/site-packages/torch/autograd/graph.py:825: UserWarning: Error detected in MulBackward0. Traceback of forward call that caused the error:
 File "<frozen runpy>", line 198, in _run_module_as_main
 File "<frozen runpy>", line 88, in _run_code
 File "/mnt/storage1/geranium/blt/bytelatent/train.py", line 812, in <module>
 main()
 File "/mnt/storage1/geranium/blt/bytelatent/train.py", line 808, in main
 train(train_args)
 File "/mnt/storage1/geranium/blt/bytelatent/train.py", line 475, in train
 pred = model(
 File "/mnt/storage1/geranium/blt/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
 return self._call_impl(*args, **kwargs)
 File "/mnt/storage1/geranium/blt/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
 return inner()
 File "/mnt/storage1/geranium/blt/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in inner
 result = forward_call(*args, **kwargs)
 File "/mnt/storage1/geranium/blt/bytelatent/model/blt.py", line 949, in forward
 (h_encoder, h_cross), cache_encoder = self.local_encoder(
 File "/mnt/storage1/geranium/blt/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
 return self._call_impl(*args, **kwargs)
 File "/mnt/storage1/geranium/blt/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
 return forward_call(*args, **kwargs)
 File "/mnt/storage1/geranium/blt/bytelatent/model/local_models.py", line 280, in forward
 patch_embeds = self.apply_cross_attention(
 File "/mnt/storage1/geranium/blt/bytelatent/model/local_models.py", line 306, in apply_cross_attention
 patch_embeds_cross = self.cross_attn_layers[layer_idx](
 File "/mnt/storage1/geranium/blt/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
 return self._call_impl(*args, **kwargs)
 File "/mnt/storage1/geranium/blt/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1844, in _call_impl
 return inner()
 File "/mnt/storage1/geranium/blt/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in inner
 result = forward_call(*args, **kwargs)
 File "/mnt/storage1/geranium/blt/bytelatent/model/latent_transformer.py", line 87, in forward
 x_norm = self.cross_attn_norm_q(x)
 File "/mnt/storage1/geranium/blt/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
 return self._call_impl(*args, **kwargs)
 File "/mnt/storage1/geranium/blt/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
 return forward_call(*args, **kwargs)
 File "/mnt/storage1/geranium/blt/.venv/lib/python3.11/site-packages/torch/nn/modules/normalization.py", line 401, in forward
 return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
 File "/mnt/storage1/geranium/blt/.venv/lib/python3.11/site-packages/torch/nn/functional.py", line 2919, in rms_norm
 return torch.rms_norm(input, normalized_shape, weight, eps)
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:110.)
 return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

Fix

Replacing the in-place addition with an out-of-place one creates a new tensor instead of mutating the view, which prevents the runtime error during the backward pass:

    return patch_embeds + patch_embeds_cross
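
For comparison, the same hedged sketch with the out-of-place addition runs cleanly through backward (again using torch.split as a stand-in for the view-producing code, and a random tensor in place of patch_embeds_cross):

    import torch

    x = torch.randn(2, 3, requires_grad=True)
    a, b = x.split(1, dim=0)     # views of the same base, as before
    cross = torch.randn_like(a)  # stand-in for patch_embeds_cross

    out = a + cross              # out-of-place: allocates a fresh tensor
    out.sum().backward()         # backward succeeds; gradients reach x
    print(x.grad.shape)          # torch.Size([2, 3])

The equivalent two-step form, patch_embeds = patch_embeds + patch_embeds_cross, would also work, since rebinding the name (unlike +=) does not mutate the view; returning the sum directly just avoids the extra assignment.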

Versions

The environment was set up using create_env.sh.

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Mar 18, 2025.