Open
Description
I'm using NNX for a toy transformer on Wikitext-103, and I'm observing that roughly one in every 100 steps takes much, much longer than the rest (on the order of 2 seconds vs. 0.02 seconds). I've managed to track down the culprit with a profile, and it seems that some NNX internal machinery in nnx.split is taking the bulk of the time.
Is there anything NNX-related that could be causing this to take a long time?
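
For anyone trying to reproduce this, here is a minimal sketch of how the slow steps could be flagged from per-step wall-clock timings. It is not the code from the original run: `time_steps`, `flag_outliers`, `step_fn`, and the 10x-median threshold are all hypothetical names and choices made for illustration. With JAX, `step_fn` should block on its result (for example with `jax.block_until_ready`), since dispatch is asynchronous and the timings would otherwise be misleading.

```python
import statistics
import time


def time_steps(step_fn, num_steps):
    """Call step_fn(i) num_steps times and return the wall-clock seconds per call.

    step_fn is a hypothetical stand-in for one training step. With JAX it
    should block on its outputs before returning, otherwise only the
    dispatch time is measured.
    """
    durations = []
    for i in range(num_steps):
        start = time.perf_counter()
        step_fn(i)
        durations.append(time.perf_counter() - start)
    return durations


def flag_outliers(durations, threshold=10.0):
    """Return (step_index, seconds) for steps slower than threshold * median."""
    median = statistics.median(durations)
    return [(i, d) for i, d in enumerate(durations) if d > threshold * median]


# Timings shaped like the ones described above: mostly 0.02 s, one 2 s step.
example = [0.02] * 9 + [2.0]
print(flag_outliers(example))
```

Using the median rather than the mean as the baseline keeps the rare 2-second steps from inflating the threshold they are compared against.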