If I switch my model from running transformer blocks in a for loop to running them with `remat_scan`, I take a pretty big performance hit: 36%. See this Colab, particularly when I call `run_transformer_noscan_j` and `run_transformer_withscan_j`. Is this to be expected? In my actual model, decoding is also ~3x slower, though I can't replicate that in the Colab; it might have to do with my having written a traceable version for the real model.

So, is this a bug? How much of a penalty should I expect from using `(remat_)scan`? What's the best way to get a decent balance between compile time and run time? The scanless version of my model takes ~12 minutes to JIT on a single GPU, and long enough on 8 GPUs that I've always given up before it finished, so I need to do something.
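For anyone landing here, a minimal sketch of the trade-off I mean, written in pure JAX (`jax.lax.scan` + `jax.checkpoint`, which is the mechanism underlying Flax's `remat_scan`). The `block`, `run_noscan`, and `run_withscan` names are illustrative, not the functions from my Colab:

```python
import jax
import jax.numpy as jnp

# A toy stand-in for a transformer block: one weight matrix per layer.
def block(x, w):
    return jnp.tanh(x @ w)

def init_params(key, num_layers, dim):
    keys = jax.random.split(key, num_layers)
    # Stack per-layer weights along a leading "layer" axis so scan can consume them.
    return jnp.stack([jax.random.normal(k, (dim, dim)) / jnp.sqrt(dim) for k in keys])

# Variant 1: unrolled Python for-loop. Fast at runtime, but XLA sees every
# layer inlined, so compile time grows with depth.
@jax.jit
def run_noscan(x, stacked_params):
    for w in stacked_params:
        x = block(x, w)
    return x

# Variant 2: lax.scan over rematerialized blocks. Compile time is roughly
# constant in depth, but checkpointing recomputes activations in the backward
# pass, and scan can block some cross-layer XLA optimizations.
@jax.jit
def run_withscan(x, stacked_params):
    def body(carry, w):
        return jax.checkpoint(block)(carry, w), None
    x, _ = jax.lax.scan(body, x, stacked_params)
    return x
```

Both variants compute the same function (scan walks the leading layer axis in the same order as the loop), so any runtime gap between them is purely the cost of scan plus rematerialization.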