Conversation

@wasertech

What does this PR do?

I cannot run inference with this model architecture on Turing GPUs. This PR fixes the issue at its root.

See #42371, vllm-project/vllm#29349, vllm-project/vllm#30635 and https://huggingface.co/swiss-ai/Apertus-8B-Instruct-2509/discussions/21 for context.

Fixes `RuntimeError: expected m1 and m2 to have the same dtype, but got: float != c10::Half`.

Issue reproduction:

import torch
from torch import nn
from transformers.models.apertus.configuration_apertus import ApertusConfig
from transformers.models.apertus.modeling_apertus import ApertusMLP
import sys

def test_apertus_mlp_crash():
    print(f"Python executable: {sys.executable}")
    print("Testing ApertusMLP crash on float16...")
    
    # 1. Setup Config with float16 and xielu
    config = ApertusConfig(
        hidden_size=64,
        intermediate_size=128,
        hidden_act="xielu",
        torch_dtype="float16" # Simulating the runner or user choice
    )
    
    # 2. Instantiate MLP
    # vLLM and similar runners usually set the torch default dtype or cast the model,
    # so set it manually here in case the module doesn't pick it up from the config.
    torch.set_default_dtype(torch.float16)  # Simulating a "float16 hardware" env
    
    try:
        mlp = ApertusMLP(config)
        
        # Check act_fn dtype
        if hasattr(mlp.act_fn, 'beta'):
            print(f"XIELU beta dtype: {mlp.act_fn.beta.dtype}")
        
        # 3. Create Input
        x = torch.randn(1, 64, dtype=torch.float16)
        
        # 4. Forward Pass
        output = mlp(x)
        print("Forward pass successful!")
        print(f"Output dtype: {output.dtype}")
        
    except RuntimeError as e:
        print("\nCaught expected RuntimeError:")
        print(e)
    except Exception as e:
        print(f"\nCaught unexpected exception: {type(e)}")
        print(e)
    finally:
        torch.set_default_dtype(torch.float32)  # Cleanup

if __name__ == "__main__":
    test_apertus_mlp_crash()

Output without this patch:

Python executable: /home/waser/Projets/Transformers/transformers/venv/bin/python
Testing ApertusMLP crash on float16...
CUDA-fused xIELU not available (No module named 'xielu') – falling back to a Python version.
For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`
XIELU beta dtype: torch.bfloat16

Caught expected RuntimeError:
expected m1 and m2 to have the same dtype, but got: float != c10::Half
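
The output above shows the root cause: the activation's beta parameter keeps its bfloat16 default while the rest of the module runs in float16. Under PyTorch's type-promotion rules, mixing bfloat16 with float16 yields float32, so the activation output is promoted to float32 and the float16 down projection then rejects it. A minimal illustration of that promotion (separate from the reproduction script above):

import torch

# bfloat16 combined with float16 promotes to float32 ...
print(torch.promote_types(torch.bfloat16, torch.float16))  # torch.float32

# ... and a float32 input fed into a float16 Linear raises the same kind of dtype mismatch error.
linear = torch.nn.Linear(4, 4, dtype=torch.float16)
try:
    linear(torch.randn(1, 4))  # default float32 input
except RuntimeError as e:
    print(e)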

Output with this patch:

Python executable: /home/waser/Projets/Transformers/transformers/venv/bin/python
Testing ApertusMLP crash on float16...
CUDA-fused xIELU not available (No module named 'xielu') – falling back to a Python version.
For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`
`torch_dtype` is deprecated! Use `dtype` instead!
XIELU beta dtype: torch.float16
Forward pass successful!
Output dtype: torch.float16

Big thanks to everyone who helped shape this patch and bring it to light.


@github-actions

[For maintainers] Suggested jobs to run (before merge)

run-slow: apertus

Initialize XIELU activation with correct dtype from config (using config.dtype instead of default bfloat16) to prevent promotion to float32 and subsequent crashes on Turing/float16 GPUs.
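
The sketch below illustrates the idea; class and parameter names are illustrative assumptions, not the exact transformers API or the literal diff. The point is that the xIELU activation's parameters are created in the dtype the config requests instead of always bfloat16.

import torch
from torch import nn

class XIELUSketch(nn.Module):
    # Simplified stand-in for the xIELU Python fallback; only the parameter
    # initialization matters for this illustration.
    def __init__(self, dtype=torch.bfloat16):
        super().__init__()
        # alpha/beta values are placeholders; what matters is the dtype they are created in.
        self.alpha_p = nn.Parameter(torch.tensor(0.8, dtype=dtype))
        self.alpha_n = nn.Parameter(torch.tensor(0.8, dtype=dtype))
        self.beta = torch.tensor(0.5, dtype=dtype)

# Before the patch (illustrative): the activation keeps its bfloat16 default
# even when the model is built in float16, setting up the float32 promotion.
print(XIELUSketch().beta.dtype)                       # torch.bfloat16

# After the patch (illustrative): the model forwards config.dtype, so the
# activation's parameters match the rest of the weights.
print(XIELUSketch(dtype=torch.float16).beta.dtype)    # torch.float16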