XLNet: relative_positional_encoding computes on CPU every forward pass (missing device= in torch.arange) #44737

@mvstrauss

Description

System Info

  • transformers version: confirmed on 4.30.2, still present on main as of 2026-03-16
  • Platform: Linux (tested on NVIDIA/AMD GPUs)
  • PyTorch: 2.x
  • Python: 3.10+

Who can help?

@ArthurZucker @Rocketknight1

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

XLNetModel.relative_positional_encoding (in src/transformers/models/xlnet/modeling_xlnet.py, lines ~940–976) creates all intermediate tensors on CPU because torch.arange() is called without device=. The entire sinusoidal positional encoding computation (arange → pow → einsum → sin/cos → cat → expand) runs on CPU every forward pass, with only the final result being copied to GPU via .to(output_h.device) on line ~1143.

There are four affected torch.arange calls:

# Line 942 — freq_seq on CPU
freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.int64).float()

# Line 955 — fwd_pos_seq on CPU (bi_data path)
fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64).float()

# Line 956 — bwd_pos_seq on CPU (bi_data path)
bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.int64).float()

# Line 971 — fwd_pos_seq on CPU (non-bi_data path)
fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64).float()

Since none of these specify device=, all downstream operations (torch.pow, torch.einsum, torch.sin, torch.cos, torch.cat, .expand) also execute on CPU. The .to(output_h.device) call in forward() only moves the final tensor to GPU — it does not retroactively move the computation.
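The behavior is easy to reproduce outside the model. Below is a minimal standalone sketch of the same computation (simplified: no int64 cast, toy sizes chosen here for illustration) showing every intermediate landing on CPU:

```python
import torch

# Standalone sketch of the sinusoidal encoding pattern used in
# relative_positional_encoding (simplified, hypothetical sizes).
d_model = 64
freq_seq = torch.arange(0, d_model, 2.0)   # no device= -> CPU
inv_freq = 1 / torch.pow(10000, freq_seq / d_model)
pos_seq = torch.arange(16, -16, -1.0)      # no device= -> CPU
sinusoid = torch.einsum("i,d->id", pos_seq, inv_freq)
pos_emb = torch.cat([torch.sin(sinusoid), torch.cos(sinusoid)], dim=-1)

print(pos_emb.device)  # cpu, regardless of where the model lives

# forward() then pays a host-to-device copy for the result alone;
# none of the work above is moved retroactively:
if torch.cuda.is_available():
    pos_emb_gpu = pos_emb.to("cuda")  # copy + sync, every forward pass
```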

relative_positional_encoding runs on every forward pass (its output is not cached), causing:

  • Unnecessary CPU computation that the GPU could perform faster
  • A CPU→GPU memory copy and synchronization point per forward pass
  • Measurable training/inference slowdown that scales with sequence length
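The CPU-side cost can be estimated with a quick micro-benchmark. This is a sketch using a standalone reimplementation of the encoding; `rel_pos_emb` is a hypothetical helper for timing purposes, not the library API:

```python
import time
import torch

def rel_pos_emb(qlen, klen, d_model=1024, device="cpu"):
    # Sinusoidal relative encoding in the style of XLNet's "bi"
    # attention path, parameterized by device (hypothetical helper).
    freq_seq = torch.arange(0, d_model, 2.0, device=device)
    inv_freq = 1 / torch.pow(10000, freq_seq / d_model)
    pos_seq = torch.arange(klen, -qlen, -1.0, device=device)
    sinusoid = torch.einsum("i,d->id", pos_seq, inv_freq)
    return torch.cat([torch.sin(sinusoid), torch.cos(sinusoid)], dim=-1)

# Rough per-call CPU cost; paid on every forward pass and growing
# with qlen + klen (i.e. with sequence length and memory length).
start = time.perf_counter()
for _ in range(100):
    rel_pos_emb(512, 512)
elapsed_ms = (time.perf_counter() - start) * 1000 / 100
print(f"~{elapsed_ms:.2f} ms per call on CPU")
```

Absolute numbers will vary by machine; the point is that this cost recurs per step and sits on the host, serialized with the GPU work.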

Expected behavior

All tensor creation inside relative_positional_encoding should happen on the model's device so the full computation stays on GPU.

Suggested fix

Pass the model's device to all torch.arange() calls. The device can be obtained from self.word_embedding.weight.device (or equivalently self.device). Minimal diff:

 def relative_positional_encoding(self, qlen, klen, bsz=None):
     # create relative positional encoding.
-    freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.int64).float()
+    freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.int64, device=self.device).float()
     inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))

     if self.attn_type == "bi":
@@ -954,8 +954,8 @@
         raise ValueError(f"Unknown `attn_type` {self.attn_type}.")

     if self.bi_data:
-        fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64).float()
-        bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.int64).float()
+        fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64, device=self.device).float()
+        bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.int64, device=self.device).float()

         if self.clamp_len > 0:
             fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
@@ -968,7 +968,7 @@

         pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
     else:
-        fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64).float()
+        fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64, device=self.device).float()
         if self.clamp_len > 0:
             fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
         pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)

After this fix, the existing pos_emb = pos_emb.to(output_h.device) in forward() becomes a no-op (the tensor is already on the target device) and could optionally be removed as well.

Note: XLNetModel inherits from PreTrainedModel which provides the .device property, so no additional device resolution logic is needed.

We are currently working around this with a runtime monkey-patch on our side but would prefer to see this fixed upstream.
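For anyone needing an interim workaround, one possible shape for such a patch (a generic sketch, not our exact code) is a context manager that injects a default device= into torch.arange for the duration of a call:

```python
import contextlib
import torch

@contextlib.contextmanager
def arange_on(device):
    # Workaround sketch: temporarily give torch.arange a default
    # device= so unmodified library code allocates on `device`.
    real_arange = torch.arange
    def patched(*args, **kwargs):
        kwargs.setdefault("device", device)
        return real_arange(*args, **kwargs)
    torch.arange = patched
    try:
        yield
    finally:
        torch.arange = real_arange  # always restore the original

# Usage (sketch): wrap the forward call, e.g.
#   with arange_on(model.device):
#       outputs = model(input_ids=input_ids)
with arange_on("cpu"):
    t = torch.arange(4)
print(t.device)  # cpu
```

This is deliberately blunt (it affects every torch.arange call made inside the block), which is one more reason an upstream fix is preferable.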
