Define shared memory for oneAPI (Flash attention on Intel GPUs) #13
Closed
AntonOresten wants to merge 1 commit into FluxML:master
Conversation
Contributor
Author
Getting a lot of segmentation faults when trying to run the unit tests.
I know oneAPI is only minimally supported in Julia, but it still works surprisingly well with KernelAbstractions.jl, and it seems everything except flash attention already runs.
As I have an Arc A770 lying around, I wanted to be able to try some DL stuff on my home desktop.
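For context (this is not the PR's actual diff), here is a minimal sketch of how work-group local ("shared") memory is expressed with KernelAbstractions.jl and run on a oneArray; a flash-attention kernel stages tiles of Q, K and V in the same kind of memory, which is what this change concerns. The kernel name and tile size below are purely illustrative.

```julia
using oneAPI, KernelAbstractions

# Toy kernel using work-group local ("shared") memory; a flash-attention kernel
# stages tiles of Q/K/V the same way. Sizes here are illustrative, not from the PR.
@kernel function rowsum_shared!(out, @Const(x))
    i, j   = @index(Global, NTuple)
    li, lj = @index(Local,  NTuple)
    tile = @localmem Float32 (16, 16)   # shared local memory (SLM) on Intel GPUs
    tile[li, lj] = x[i, j]
    @synchronize                        # make the tile visible to the whole work-group
    if lj == 1
        s = 0.0f0
        for c in 1:16
            s += tile[li, c]
        end
        out[i] = s                      # one write per row, done by a single lane
    end
end

x   = oneArray(rand(Float32, 16, 16))   # a single 16×16 work-group, so no inter-group races
out = oneArray(zeros(Float32, 16))
backend = get_backend(x)
rowsum_shared!(backend, (16, 16))(out, x; ndrange = size(x))
KernelAbstractions.synchronize(backend)
```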
But again, since oneAPI is minimally supported across the ecosystem, naive_attention from the unit tests doesn't run because of NNlib.batched_mul. I've opened a PR for partial oneAPI support: FluxML/NNlib.jl#644
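To make the dependency concrete, here is a hedged sketch of a batched_mul-style naive attention (the test suite's naive_attention may use a different layout or scaling). Both matrix products go through NNlib.batched_mul, so without a oneAPI method for it the reference path can't run on a oneArray.

```julia
using NNlib  # batched_mul, batched_transpose, softmax

# Hedged sketch only; the repo's naive_attention may differ.
# q, k, v are (features, sequence, batch) arrays.
function naive_attention_sketch(q, k, v)
    d = size(q, 1)
    # scores[i, j, b] = key i ⋅ query j, scaled by 1/sqrt(d)  -> (L, L, B)
    scores = NNlib.batched_mul(NNlib.batched_transpose(k), q) ./ sqrt(Float32(d))
    a = NNlib.softmax(scores; dims = 1)   # normalise over keys for each query
    return NNlib.batched_mul(v, a)        # (d, L, B)
end
```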
Flash attention benchmarks

Since naive_attention doesn't work, I made a separate script for just flash attention. The timings seem to be roughly within an order of magnitude above the benchmarks in #11.
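The script itself isn't included in this comment; a minimal sketch of what such a standalone timing run could look like is below, with flash_attention standing in for this package's entry point and purely illustrative sizes.

```julia
using oneAPI, BenchmarkTools

# Illustrative sizes: 64-dim heads, sequence length 1024, batch of 8.
d, L, B = 64, 1024, 8
q = oneArray(rand(Float32, d, L, B))
k = oneArray(rand(Float32, d, L, B))
v = oneArray(rand(Float32, d, L, B))

# `flash_attention` is a stand-in name for the kernel under test in this package.
@btime begin
    flash_attention($q, $k, $v)
    oneAPI.synchronize()   # wait for the GPU so the measured time is meaningful
end
```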