Description
System Info
- `transformers` version: confirmed on 4.30.2, still present on `main` as of 2026-03-16
- Platform: Linux (tested on NVIDIA/AMD GPUs)
- PyTorch: 2.x
- Python: 3.10+
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
`XLNetModel.relative_positional_encoding` (in `src/transformers/models/xlnet/modeling_xlnet.py`, lines ~940–976) creates all intermediate tensors on CPU because `torch.arange()` is called without `device=`. The entire sinusoidal positional encoding computation (arange → pow → einsum → sin/cos → cat → expand) runs on CPU on every forward pass, and only the final result is copied to GPU via `.to(output_h.device)` at line ~1143.
There are four affected torch.arange calls:
```python
# Line 942 — freq_seq on CPU
freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.int64).float()

# Line 955 — fwd_pos_seq on CPU (bi_data path)
fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64).float()

# Line 956 — bwd_pos_seq on CPU (bi_data path)
bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.int64).float()

# Line 971 — fwd_pos_seq on CPU (non-bi_data path)
fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64).float()
```

Since none of these specify `device=`, all downstream operations (`torch.pow`, `torch.einsum`, `torch.sin`, `torch.cos`, `torch.cat`, `.expand`) also execute on CPU. The `.to(output_h.device)` call in `forward()` only moves the final tensor to GPU; it does not retroactively move the computation.
This is called on every forward pass (not cached), causing:
- Unnecessary CPU computation that the GPU could perform faster
- A CPU→GPU memory copy and synchronization point per forward pass
- Measurable training/inference slowdown that scales with sequence length
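The CPU-default behavior is easy to demonstrate in isolation. This is a minimal sketch in plain PyTorch (not the XLNet code itself) showing that factory functions called without `device=` produce CPU tensors, that downstream ops inherit that placement, and that `.to(...)` only copies the finished result:

```python
import torch

# torch.arange without device= lands on the default device (CPU),
# and every tensor derived from it inherits that placement.
d_model = 8
freq_seq = torch.arange(0, d_model, 2.0).float()
inv_freq = 1 / torch.pow(10000, freq_seq / d_model)

assert freq_seq.device.type == "cpu"  # created on CPU by default
assert inv_freq.device.type == "cpu"  # downstream op stays on CPU

# .to() copies only the finished tensor; the arange/pow above already ran
# on CPU, which is the per-forward-pass cost described in this issue.
target = "cuda" if torch.cuda.is_available() else "cpu"
pos = inv_freq.to(target)
```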
Expected behavior
All tensor creation inside relative_positional_encoding should happen on the model's device so the full computation stays on GPU.
Suggested fix
Pass the model's device to all `torch.arange()` calls. The device can be obtained from `self.word_embedding.weight.device` (or equivalently `self.device`). Minimal diff:
```diff
     def relative_positional_encoding(self, qlen, klen, bsz=None):
         # create relative positional encoding.
-        freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.int64).float()
+        freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.int64, device=self.device).float()
         inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))

         if self.attn_type == "bi":
@@ -954,8 +954,8 @@
             raise ValueError(f"Unknown `attn_type` {self.attn_type}.")

         if self.bi_data:
-            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64).float()
-            bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.int64).float()
+            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64, device=self.device).float()
+            bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.int64, device=self.device).float()

             if self.clamp_len > 0:
                 fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
@@ -968,7 +968,7 @@
             pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
         else:
-            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64).float()
+            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64, device=self.device).float()

             if self.clamp_len > 0:
                 fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)

             pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
```

After this fix, the existing `pos_emb = pos_emb.to(output_h.device)` in `forward()` becomes a no-op (same device) and could optionally be removed as well.
Note: `XLNetModel` inherits from `PreTrainedModel`, which provides the `.device` property, so no additional device-resolution logic is needed.
We are currently working around this with a runtime monkey-patch on our side but would prefer to see this fixed upstream.
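For reference, the workaround is along these lines. This is an illustrative sketch using a stand-in module rather than our exact patch; in practice `_orig` would be `XLNetModel.relative_positional_encoding` from `transformers.models.xlnet.modeling_xlnet`. It relies on `torch.device` being usable as a context manager (PyTorch >= 2.0), which routes factory calls such as `torch.arange` inside the block to the given device:

```python
import torch

class TinyModel(torch.nn.Module):
    """Stand-in exhibiting the same bug: arange without device=."""

    def __init__(self, d_model=8):
        super().__init__()
        self.d_model = d_model
        self.emb = torch.nn.Embedding(4, d_model)

    @property
    def device(self):
        # Mirrors resolving the device from an embedding weight.
        return self.emb.weight.device

    def relative_positional_encoding(self, qlen, klen):
        freq_seq = torch.arange(0, self.d_model, 2.0).float()  # CPU by default
        return 1 / torch.pow(10000, freq_seq / self.d_model)

_orig = TinyModel.relative_positional_encoding

def _patched(self, qlen, klen):
    # torch.device as a context manager makes factory functions inside
    # the original method default to the model's device.
    with torch.device(self.device):
        return _orig(self, qlen, klen)

TinyModel.relative_positional_encoding = _patched
```

Assigning the same kind of wrapper over the real method patches `XLNetModel` at runtime without touching library source, but an upstream fix would make this unnecessary.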