RoPE removed in GraniteMoEHybrid models in v5 #42862

@avihu111

System Info

  • transformers version: 5.0.0.dev0
  • Platform: Linux-5.14.0-503.23.1.el9_5.x86_64-x86_64-with-glibc2.34
  • Python version: 3.10.15
  • Huggingface_hub version: 1.2.2
  • Safetensors version: 0.4.5
  • Accelerate version: 1.11.0
  • Accelerate config: not found
  • DeepSpeed version: 0.17.5
  • PyTorch version (accelerator?): 2.7.0+cu128 (CUDA)
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA H100 80GB HBM3

Who can help?

Hi @ArthurZucker @Cyrilvallez

When running granite-4.0-micro with transformers v5 (or the latest main branch), I get degraded results.
In the working version, v4.57.3, the code applies RoPE whenever position embeddings are passed in, so it supports both RoPE for dense transformer models and no position embeddings for hybrid models.
See the relevant code here:

        cos, sin = position_embeddings if position_embeddings is not None else (None, None)
        if position_embeddings is not None:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
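
To illustrate why dropping these lines matters, here is a minimal, dependency-free sketch of the conditional RoPE step. It is not the transformers implementation: the helper names (`rope_tables`, `apply_rope`, `attention_score`) and the `base=10000.0` default are assumptions made for this example, and it works on plain Python lists rather than tensors. It follows the common "rotate half" convention, where a vector is split into two halves that are rotated against each other.

```python
import math


def rotate_half(x):
    """Map [x1, x2] -> [-x2, x1], where x1 and x2 are the two halves of x."""
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])


def rope_tables(position, dim, base=10000.0):
    """Return (cos, sin) lists for one position. `base` is an assumed default."""
    angles = [position / (base ** (2 * i / dim)) for i in range(dim // 2)]
    angles = angles + angles  # duplicated to match the half-split layout
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


def apply_rope(x, cos, sin):
    """Rotate x by its position-dependent angles: x * cos + rotate_half(x) * sin."""
    rotated = rotate_half(x)
    return [xi * c + ri * s for xi, ri, c, s in zip(x, rotated, cos, sin)]


def attention_score(q, k, q_pos, k_pos, use_rope):
    """Unscaled dot-product score between one query and one key vector."""
    if use_rope:
        q = apply_rope(q, *rope_tables(q_pos, len(q)))
        k = apply_rope(k, *rope_tables(k_pos, len(k)))
    return sum(a * b for a, b in zip(q, k))


q = [1.0, 0.5, -0.3, 0.8]
k = [0.2, -0.7, 0.9, 0.4]

# Without RoPE the score ignores token positions entirely.
print(attention_score(q, k, 3, 1, use_rope=False))
print(attention_score(q, k, 7, 2, use_rope=False))

# With RoPE the score depends on the relative offset between the positions.
print(attention_score(q, k, 3, 1, use_rope=True))
print(attention_score(q, k, 3, 2, use_rope=True))
```

In this sketch, skipping the rotation makes the score blind to token order. A dense model trained with RoPE would presumably lose its positional signal in the same way, which is consistent with the degraded output reported here.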

After upgrading to v5, the lines that apply RoPE are missing; see here

Re-applying RoPE in the v5 code version resolves the degraded performance when using granite-4.0-micro.
Happy to submit a PR to resolve this.
CC: @gabe-l-hart @shawntan @alex-jw-brooks

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Run the official code in https://huggingface.co/ibm-granite/granite-4.0-micro

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"
model_path = "ibm-granite/granite-4.0-micro"
tokenizer = AutoTokenizer.from_pretrained(model_path)
# drop device_map if running on CPU
model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
model.eval()
# change input text as desired
chat = [
    { "role": "user", "content": "Please list one IBM Research laboratory located in the United States. You should only output its name and location." },
]
chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
# tokenize the text
input_tokens = tokenizer(chat, return_tensors="pt").to(device)
# generate output tokens
output = model.generate(**input_tokens, max_new_tokens=100)
# decode output tokens into text
output = tokenizer.batch_decode(output)
# print output
print(output[0])

Expected behavior

<|start_of_role|>system<|end_of_role|>You are a helpful assistant. Please ensure responses are professional, accurate, and safe.<|end_of_text|>
<|start_of_role|>user<|end_of_role|>Please list one IBM Research laboratory located in the United States. You should only output its name and location.<|end_of_text|>
<|start_of_role|>assistant<|end_of_role|>Almaden Research Center, San Jose, California<|end_of_text|>
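
When comparing the decoded output across transformers versions, it can help to look only at the assistant turn. The sketch below assumes the role markers shown in the expected output above. `extract_assistant_reply` is a hypothetical helper written for this example, not part of transformers.

```python
def extract_assistant_reply(decoded: str) -> str:
    """Return the text of the last assistant turn in a decoded chat string."""
    start_marker = "<|start_of_role|>assistant<|end_of_role|>"
    end_marker = "<|end_of_text|>"
    # Keep everything after the last assistant header.
    _, found, tail = decoded.rpartition(start_marker)
    if not found:
        raise ValueError("no assistant turn found in decoded output")
    # Cut at the end-of-text marker if generation emitted one.
    reply, _, _ = tail.partition(end_marker)
    return reply.strip()


decoded = (
    "<|start_of_role|>user<|end_of_role|>Please list one IBM Research laboratory "
    "located in the United States.<|end_of_text|>\n"
    "<|start_of_role|>assistant<|end_of_role|>Almaden Research Center, "
    "San Jose, California<|end_of_text|>"
)
print(extract_assistant_reply(decoded))
```

Running the reproduction script under v4.57.3 and under v5 and passing each `output[0]` to this helper would give two short strings that are easy to compare side by side.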
