[BUG]: No operator found for memory_efficient_attention_forward with inputs #109

@Luca-Girotti

Description

Python Version

3.12

Pip Freeze

absl-py==1.0.0
accelerate==1.5.2
aiohttp==3.9.5
aiohttp-cors==0.8.1
aiosignal==1.2.0
annotated-types==0.7.0
anyio==4.2.0
argcomplete==3.6.2
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
astor==0.8.1
asttokens==2.0.5
astunparse==1.6.3
async-lru==2.0.4
attrs==23.1.0
audioread==3.0.1
azure-core==1.33.0
azure-cosmos==4.3.1
azure-identity==1.21.0
azure-storage-blob==12.23.0
azure-storage-file-datalake==12.17.0
Babel==2.11.0
backoff==2.2.1
bcrypt==3.2.0
beautifulsoup4==4.12.3
bitsandbytes==0.45.5
black==24.4.2
bleach==4.1.0
blinker==1.7.0
blis==0.7.11
boto3==1.34.69
botocore==1.34.69
Brotli==1.0.9
cachetools==5.3.3
catalogue==2.0.10
category-encoders==2.6.3
certifi==2024.6.2
cffi==1.16.0
chardet==4.0.0
charset-normalizer==2.0.4
circuitbreaker==2.1.3
click==8.1.7
cloudpathlib==0.21.0
cloudpickle==2.2.1
cmdstanpy==1.2.5
colorful==0.5.6
colorlog==6.9.0
comm==0.2.1
composer==0.29.0
confection==0.1.5
configparser==5.2.0
contourpy==1.2.0
coolname==2.2.0
cryptography==42.0.5
cycler==0.11.0
cymem==2.0.11
Cython==3.0.11
dacite==1.9.2
databricks-automl-runtime==0.2.21
databricks-feature-engineering==0.10.2
databricks-sdk==0.30.0
datasets==3.5.0
dbl-tempo==0.1.26
dbus-python==1.3.2
debugpy==1.6.7
decorator==5.1.1
deepspeed==0.16.5
defusedxml==0.7.1
Deprecated==1.2.18
dill==0.3.8
distlib==0.3.8
distro==1.9.0
distro-info==1.7+build1
dm-tree==0.1.9
docstring-to-markdown==0.11
docstring_parser==0.16
einops==0.8.1
entrypoints==0.4
evaluate==0.4.3
executing==0.8.3
facets-overview==1.1.1
Farama-Notifications==0.0.4
fastapi==0.115.12
fastjsonschema==2.21.1
fasttext-wheel==0.9.2
filelock==3.13.1
fire==0.7.0
flash_attn==2.7.4.post1
Flask==2.2.5
flatbuffers==25.2.10
fonttools==4.51.0
fqdn==1.5.1
frozenlist==1.4.0
fsspec==2023.5.0
future==0.18.3
gast==0.4.0
gitdb==4.0.11
GitPython==3.1.37
google-api-core==2.20.0
google-auth==2.21.0
google-auth-oauthlib==1.2.1
google-cloud-core==2.4.3
google-cloud-storage==2.10.0
google-crc32c==1.7.1
google-pasta==0.2.0
google-resumable-media==2.7.2
googleapis-common-protos==1.70.0
gql==3.5.2
graphql-core==3.2.4
greenlet==3.0.1
grpcio==1.60.0
grpcio-status==1.60.0
gunicorn==20.1.0
gviz-api==1.10.0
gymnasium==0.28.1
h11==0.14.0
h5py==3.11.0
hjson==3.1.0
holidays==0.54
htmlmin==0.1.12
httpcore==1.0.8
httplib2==0.20.4
httpx==0.28.1
huggingface-hub==0.29.3
idna==3.7
ImageHash==4.3.1
imageio==2.33.1
imbalanced-learn==0.12.3
importlib-metadata==6.0.0
importlib_resources==6.5.2
ipyflow-core==0.0.201
ipykernel==6.28.0
ipython==8.25.0
ipython-genutils==0.2.0
ipywidgets @ file:///databricks/.virtualenv-def/ipywidgets-7.7.2-2databricksnojsdeps-py2.py3-none-any.whl#sha256=903ead20c8d40de671853515fcad2f34b43ebf3eff80e4df3f876b8dd64c903b
isodate==0.6.1
isoduration==20.11.0
itsdangerous==2.2.0
jax-jumpy==1.0.0
jedi==0.19.1
Jinja2==3.1.4
jiter==0.9.0
jmespath==1.0.1
joblib==1.4.2
joblibspark==0.5.3
json5==0.9.6
jsonpatch==1.33
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2023.7.1
jupyter-events==0.10.0
jupyter-lsp==2.2.0
jupyter_client==8.6.0
jupyter_core==5.7.2
jupyter_server==2.14.1
jupyter_server_terminals==0.4.4
jupyterlab==4.0.11
jupyterlab-pygments==0.1.2
jupyterlab_server==2.25.1
keras==3.9.0
kiwisolver==1.4.4
langchain==0.3.21
langchain-core==0.3.51
langchain-text-splitters==0.3.8
langcodes==3.5.0
langsmith==0.1.133
language_data==1.3.0
launchpadlib==1.11.0
lazr.restfulclient==0.14.6
lazr.uri==1.0.6
lazy_loader==0.4
libclang==15.0.6.1
librosa==0.10.2
lightgbm==4.5.0
lightning-utilities==0.14.3
linkify-it-py==2.0.0
llvmlite==0.42.0
lz4==4.3.2
Mako==1.2.0
marisa-trie==1.2.0
Markdown==3.4.1
markdown-it-py==2.2.0
MarkupSafe==2.1.3
matplotlib==3.8.4
matplotlib-inline==0.1.6
mccabe==0.7.0
mdit-py-plugins==0.3.0
mdurl==0.1.0
memray==1.17.1
mistral_common==1.5.4
mistune==2.0.4
ml-dtypes==0.4.1
mlflow-skinny==2.21.3
mosaicml-cli==0.6.41
mosaicml-streaming==0.11.0
mpmath==1.3.0
msal==1.32.0
msal-extensions==1.3.1
msgpack==1.1.0
multidict==6.0.4
multimethod==1.12
multiprocess==0.70.16
murmurhash==1.0.12
mypy==1.10.0
mypy-extensions==1.0.0
namex==0.0.8
nbclient==0.8.0
nbconvert==7.10.0
nbformat==5.9.2
nest-asyncio==1.6.0
networkx==3.2.1
ninja==1.11.1.1
nltk==3.8.1
nodeenv==1.9.1
notebook==7.0.8
notebook_shim==0.2.3
numba==0.59.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-cusparselt-cu12==0.6.2
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.0
oci==2.150.0
openai==1.69.0
opencensus==0.11.4
opencensus-context==0.1.3
opentelemetry-api==1.32.0
opentelemetry-sdk==1.32.0
opentelemetry-semantic-conventions==0.53b0
opt_einsum==3.4.0
optree==0.15.0
optuna==3.6.1
optuna-integration==3.6.0
orjson==3.10.16
overrides==7.4.0
packaging==24.1
pandas==1.5.3
pandocfilters==1.5.0
paramiko==3.4.0
parso==0.8.3
pathspec==0.10.3
patsy==0.5.6
peft==0.15.2
pexpect==4.8.0
phik==0.12.4
pillow==10.3.0
platformdirs==3.10.0
plotly==5.22.0
pluggy==1.0.0
pmdarima==2.0.4
pooch==1.8.2
preshed==3.0.9
prometheus-client==0.14.1
prompt-toolkit==3.0.43
prophet==1.1.6
proto-plus==1.26.1
protobuf==4.24.1
psutil==5.9.0
psycopg2==2.9.3
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
py-spy==0.4.0
pyarrow==15.0.2
pyasn1==0.4.8
pyasn1-modules==0.2.8
pybind11==2.13.6
pyccolo==0.0.65
pycparser==2.21
pydantic==2.8.2
pydantic_core==2.20.1
pyflakes==3.2.0
Pygments==2.15.1
PyGObject==3.48.2
PyJWT==2.7.0
PyNaCl==1.5.0
pyodbc==5.0.1
pyOpenSSL==24.0.0
pyparsing==3.0.9
pyright==1.1.294
pytesseract==0.3.10
python-apt==2.7.7+ubuntu4
python-dateutil==2.9.0.post0
python-editor==1.0.4
python-json-logger==2.0.7
python-lsp-jsonrpc==1.1.2
python-lsp-server==1.10.0
python-snappy==0.6.1
pytoolconfig==1.2.6
pytorch-ranger==0.1.1
pytz==2024.1
PyWavelets==1.5.0
PyYAML==6.0.1
pyzmq==25.1.2
questionary==2.1.0
ray==2.37.0
referencing==0.30.2
regex==2023.10.3
requests==2.32.2
requests-oauthlib==1.3.1
requests-toolbelt==1.0.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.3.5
rope==1.12.0
rpds-py==0.10.6
rsa==4.9
ruamel.yaml==0.18.10
ruamel.yaml.clib==0.2.12
s3transfer==0.10.2
safetensors==0.4.4
scikit-image==0.23.2
scikit-learn==1.4.2
scipy==1.13.1
seaborn==0.13.2
Send2Trash==1.8.2
sentence-transformers==3.4.1
sentencepiece==0.2.0
setuptools==74.0.0
shap==0.46.0
shellingham==1.5.4
simple-parsing==0.1.7
simplejson==3.17.6
six==1.16.0
slicer==0.0.8
smart-open==5.2.1
smmap==5.0.0
sniffio==1.3.0
soundfile==0.12.1
soupsieve==2.5
soxr==0.5.0.post1
spacy==3.7.5
spacy-legacy==3.0.12
spacy-loggers==1.0.5
SQLAlchemy==2.0.30
sqlparse==0.4.2
srsly==2.5.1
ssh-import-id==5.11
stack-data==0.2.0
stanio==0.5.1
starlette==0.46.2
statsmodels==0.14.2
sympy==1.13.1
tabulate==0.9.0
tangled-up-in-unicode==0.2.0
tenacity==8.2.2
tensorboard==2.17.0
tensorboard-data-server==0.7.2
tensorboard-plugin-profile==2.17.0
tensorboardX==2.6.2.2
tensorflow==2.17.0
tensorflow-estimator==2.15.0
termcolor==3.0.1
terminado==0.17.1
textual==3.1.0
tf_keras==2.17.0
thinc==8.2.5
threadpoolctl==2.2.0
tifffile==2023.4.12
tiktoken==0.7.0
tinycss2==1.2.1
tokenize-rt==4.2.1
tokenizers==0.21.0
tomli==2.0.1
torch==2.2.0
torch-optimizer==0.3.0
torcheval==0.0.7
torchmetrics==1.6.0
torchvision==0.21.0+cu124
tornado==6.4.1
tqdm==4.66.4
traitlets==5.14.3
transformers==4.50.2
triton==2.2.0
typeguard==4.4.2
typer==0.15.2
types-protobuf==3.20.3
types-psutil==5.9.0
types-pytz==2023.3.1.1
types-PyYAML==6.0.0
types-requests==2.31.0.0
types-setuptools==68.0.0.0
types-six==1.16.0
types-urllib3==1.26.25.14
typing_extensions==4.11.0
uc-micro-py==1.0.1
ujson==5.10.0
unattended-upgrades==0.1
uri-template==1.3.0
urllib3==1.26.16
uvicorn==0.34.1
validators==0.34.0
virtualenv==20.26.2
visions==0.7.5
wadllib==1.3.6
wasabi==1.1.3
wcwidth==0.2.5
weasel==0.4.1
webcolors==24.11.1
webencodings==0.5.1
websocket-client==1.8.0
websockets==11.0.3
Werkzeug==3.0.3
whatthepatch==1.0.2
wheel==0.43.0
wordcloud==1.9.4
wrapt==1.14.1
xformers==0.0.24
xgboost==2.0.3
xgboost-ray==0.1.19
xxhash==3.4.1
yapf==0.33.0
yarl==1.9.3
ydata-profiling==4.9.0
zipp==3.17.0
zstd==1.5.5.1

Reproduction Steps

I am on Databricks and these are the 2 cells that I am currently running:

  1. %pip install -r /dbfs/FileStore/mistral-finetunement/requirements.txt
     %pip install accelerate peft bitsandbytes transformers fire

  2. import sys
     import os
     import runpy
     import torch

     os.environ["FLASH_ATTENTION_FORCE_DISABLED"] = "1"

     os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

     sys.path.insert(0, "/dbfs/FileStore/mistral-finetunement")
     print("Using device:", "cuda" if torch.cuda.is_available() else "cpu")

     os.chdir("/dbfs/FileStore/mistral-finetunement")

     import train
     train.train("/dbfs/FileStore/mistral-finetunement/example/7B.yaml")
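Because the training call in cell 2 only fails once attention is first invoked, a short pre-flight check run before `train.train(...)` could surface the same problem within seconds. The following is a minimal sketch, not part of the original report: `preflight` is a hypothetical helper name, and the tiny tensor shape is an arbitrary assumption chosen only to exercise operator dispatch with the same dtype and head dimension as in the error.

```python
def preflight():
    """Try one tiny xFormers attention call and report the outcome as a string.

    Hypothetical helper: returns a short status instead of raising, so it can
    be run in a notebook cell before starting a long training job.
    """
    try:
        import torch
        import xformers.ops as xops
    except Exception as exc:  # missing package or a broken native library
        return f"import failed: {exc}"

    if not torch.cuda.is_available():
        return "CUDA not available"

    # Layout is [batch, seq_len, heads, head_dim]; float16 and head_dim=128
    # mirror the inputs listed in the reported error, at a much smaller size.
    q = torch.randn(1, 16, 4, 128, device="cuda", dtype=torch.float16)
    try:
        xops.memory_efficient_attention(q, q, q)
    except NotImplementedError as exc:
        # Same exception type as in the report; keep only its first line.
        return f"no operator: {str(exc).splitlines()[0]}"
    return "ok"


if __name__ == "__main__":
    print(preflight())
```

A result other than `ok` would suggest fixing the environment before spending time on model loading and sharding.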

Expected Behavior

Whenever I run the train file, it runs for about 40 minutes: it loads the model on 1 CPU, shards the model, loads the dataset, and then fails with the following error:

NotImplementedError: No operator found for memory_efficient_attention_forward with inputs:
query : shape=(1, 32768, 32, 128) (torch.float16)
key : shape=(1, 32768, 32, 128) (torch.float16)
value : shape=(1, 32768, 32, 128) (torch.float16)
attn_bias : <class 'xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask'>
p : 0.0
flshattF@[version] is not supported because:
xFormers wasn't build with CUDA support
operator wasn't built - see python -m xformers.info for more info
tritonflashattF is not supported because:
xFormers wasn't build with CUDA support
attn_bias type is <class 'xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask'>
operator wasn't built - see python -m xformers.info for more info
triton is not available
Only work on pre-MLIR triton for now
cutlassF is not supported because:
xFormers wasn't build with CUDA support
operator wasn't built - see python -m xformers.info for more info
smallkF is not supported because:
max(query.shape[-1] != value.shape[-1]) > 32
xFormers wasn't build with CUDA support
dtype=torch.float16 (supported: {torch.float32})
attn_bias type is <class 'xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask'>
operator wasn't built - see python -m xformers.info for more info
unsupported embed per head: 128
File , line 17
14 os.chdir("/dbfs/FileStore/mistral-finetunement")
16 import train
---> 17 train.train("/dbfs/FileStore/mistral-finetunement/example/7B.yaml")
File /local_disk0/.ephemeral_nfs/envs/pythonEnv-ef141b96-ec7d-42b5-90a3-f628ec089a4c/lib/python3.12/site-packages/xformers/ops/fmha/dispatch.py:63, in _run_priority_list(name, priority_list, inp)
61 for op, not_supported in zip(priority_list, not_supported_reasons):
62 msg += "\n" + _format_not_supported_reasons(op, not_supported)
---> 63 raise NotImplementedError(msg)
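Every operator in the message reports "xFormers wasn't build with CUDA support", and the pip freeze above lists `torch==2.2.0` next to `torchvision==0.21.0+cu124` and CUDA 12.1 runtime packages. A quick way to see which builds are actually installed is sketched below. This is only an illustrative diagnostic under those assumptions, not a confirmed root cause: `installed_versions` and `cuda_tag` are hypothetical helper names, and the package list is an assumption about which builds are worth comparing.

```python
import re
from importlib import metadata

# Assumed list of packages whose builds are worth comparing side by side.
PACKAGES = ["torch", "torchvision", "xformers", "triton", "flash_attn"]


def installed_versions(names=PACKAGES):
    """Return {package: version string, or None if not installed}."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def cuda_tag(version):
    """Extract a local CUDA tag such as 'cu124' from '0.21.0+cu124', else None."""
    if version is None:
        return None
    match = re.search(r"\+(cu\d+)", version)
    return match.group(1) if match else None


if __name__ == "__main__":
    for name, version in installed_versions().items():
        shown = version or "not installed"
        print(f"{name:12s} {shown:20s} cuda_tag={cuda_tag(version)}")
```

The output complements `python -m xformers.info`, which the error message itself recommends for checking which operators were built.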

Additional Context

No response

Suggested Solutions

No response
