
[BUG: inference with Blackwell GPU throws invalid argument in flash-attention/flash_fwd_launch_template.h] #250

@koiker

Description


Python -VV

Python 3.12.11 (main, Jul 20 2025, 00:11:56) [GCC 13.3.0]

Pip Freeze

accelerate==1.9.0
annotated-types==0.7.0
attrs==25.3.0
certifi==2025.7.14
charset-normalizer==3.4.2
defusedxml==0.7.1
docstring_parser==0.16
filelock==3.13.1
fire==0.7.0
fsspec==2024.6.1
Glances==4.3.3
hf-xet==1.1.5
huggingface-hub==0.33.4
idna==3.10
Jinja2==3.1.4
jsonschema==4.25.0
jsonschema-specifications==2025.4.1
MarkupSafe==2.1.5
mistral_common==1.8.1
mistral_inference==1.6.0
mpmath==1.3.0
networkx==3.3
numpy==2.1.2
nvidia-cublas-cu12==12.8.3.14
nvidia-cuda-cupti-cu12==12.8.57
nvidia-cuda-nvrtc-cu12==12.8.61
nvidia-cuda-runtime-cu12==12.8.57
nvidia-cudnn-cu12==9.7.1.26
nvidia-cufft-cu12==11.3.3.41
nvidia-cufile-cu12==1.13.0.11
nvidia-curand-cu12==10.3.9.55
nvidia-cusolver-cu12==11.7.2.55
nvidia-cusparse-cu12==12.5.7.53
nvidia-cusparselt-cu12==0.6.3
nvidia-ml-py==12.575.51
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.8.61
nvidia-nvtx-cu12==12.8.55
nvitop==1.5.1
packaging==25.0
peft==0.16.0
pillow==11.0.0
psutil==7.0.0
pycountry==24.6.1
pydantic==2.11.7
pydantic-extra-types==2.10.5
pydantic_core==2.33.2
PyYAML==6.0.2
referencing==0.36.2
regex==2024.11.6
requests==2.32.4
rpds-py==0.26.0
safetensors==0.5.3
sentencepiece==0.2.0
setuptools==70.2.0
shtab==1.7.2
simple-parsing==0.1.7
sympy==1.13.3
termcolor==3.1.0
tiktoken==0.9.0
tokenizers==0.21.2
torch==2.7.1+cu128
torchaudio==2.7.1+cu128
torchvision==0.22.1+cu128
tqdm==4.67.1
transformers==4.53.2
triton==3.3.1
typing-inspection==0.4.1
typing_extensions==4.12.2
urllib3==2.5.0
xformers==0.0.31.post1

Reproduction Steps

Install the required dependencies, copy a Mistral model to a local folder and invoke:
mistral-chat <model_path> --instruct --max_tokens 256

I used Devstral-Small-2507 as the model to load, and the prompt was: write hello world in python
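
As a quick sanity check, the sketch below (plain PyTorch, nothing from mistral_inference) confirms the card reports Blackwell's compute capability:

```python
# Sanity-check sketch: print the GPU's compute capability before suspecting
# kernel selection in the attention backend.
import torch

major, minor = torch.cuda.get_device_capability()
print(f"{torch.cuda.get_device_name()}: sm_{major}{minor}")
# RTX 50xx / RTX PRO Blackwell cards report sm_120, which is newer than the
# architectures the bundled flash-attention build appears to target.
```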

Expected Behavior

A simple response like:
print("Hello, World!)

Additional Context

Here is the traceback from executing the prompt.

/.venv/lib/python3.12/site-packages/mistral_common/tokens/tokenizers/tekken.py:337: FutureWarning: The attributed special_token_policy is deprecated and will be removed in 1.10.0. Please pass a special token policy explicitly to the relevant methods.
warnings.warn(
Prompt: write hello world in python
CUDA error (/__w/xformers/xformers/third_party/flash-attention/hopper/flash_fwd_launch_template.h:188): invalid argument

From my research, the bundled third_party/flash-attention does not work correctly with the newest generation of GPUs (Blackwell: RTX 50xx and RTX PRO).
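
To isolate this from mistral_inference, here is a minimal sketch, assuming the xformers dispatcher picks the same flash-attention path: calling memory_efficient_attention directly on a Blackwell card should hit the same CUDA error.

```python
# Minimal repro sketch, independent of mistral_inference. On a Blackwell card
# this is expected to raise the same "invalid argument" CUDA error if xformers
# dispatches to the bundled flash-attention kernel.
import torch
from xformers.ops import memory_efficient_attention

# xformers expects tensors shaped (batch, seq_len, num_heads, head_dim)
q = torch.randn(1, 128, 8, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 128, 8, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 128, 8, 128, device="cuda", dtype=torch.bfloat16)

out = memory_efficient_attention(q, k, v)
print(out.shape)
```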

Suggested Solutions

Make flash attention optional, and/or update the flash-attention build in the third-party folder.
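
For the "optional" half of the suggestion, a minimal sketch of what a fallback could look like; the attention helper here is hypothetical (not mistral_inference's actual API), and the capability cutoff assumes sm_120 (Blackwell) is the unsupported case:

```python
# Sketch of an optional flash-attention path. The helper name and the sm_120
# cutoff are assumptions for illustration.
import torch
import torch.nn.functional as F
from xformers.ops import memory_efficient_attention

def attention(q, k, v):
    if torch.cuda.get_device_capability() >= (12, 0):
        # Fall back to PyTorch's built-in SDPA on Blackwell. SDPA expects
        # (batch, heads, seq, head_dim); xformers uses (batch, seq, heads, head_dim),
        # so transpose around the call.
        return F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        ).transpose(1, 2)
    return memory_efficient_attention(q, k, v)
```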
