Not sure the exact number you should expect, but with …
If I switch my model from running transformer blocks in a Python for loop to running them with `remat_scan`, I eat a pretty big performance hit, about 36%. See this Colab, particularly when I call `run_transformer_noscan_j` and `run_transformer_withscan_j`. Is this to be expected? In my actual model, decoding is also ~3x slower, though I can't replicate that in the Colab; it might have to do with me having written a traceable version for the real model.

So, is this a bug? How much of a penalty should I expect to eat when using `(remat_)scan`? What's the best way to strike a decent balance between compile time and run time? The scanless version of my model takes ~12 minutes to JIT on a single GPU, and long enough on 8 GPUs that I've always given up before it finished, so I need to do something.