
Slow training step occasionally due to slow graph flatten #4336

Open
@kriscao-cohere


I'm using NNX for a toy transformer on Wikitext-103, and I'm seeing one step in roughly every 100 take far longer than the rest (on the order of 2 seconds, versus about 0.02 seconds for a normal step). I've managed to track down the culprit with a profile, and it seems that some NNX internal machinery in nnx.split is taking the bulk of the time:

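[Image: profiler trace of a slow step, with most of the time spent in NNX graph-flattening code under nnx.split]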
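To reproduce the measurement above (a ~2 s step roughly once every ~100 steps, against ~0.02 s normally), it helps to time every step and flag the outliers before profiling them. The sketch below is a minimal, hypothetical harness: the helper names `time_steps` and `slow_steps` and the 10x-median threshold are illustrative assumptions, not part of NNX or JAX. Because JAX dispatches work asynchronously, the step function you pass in should block on its outputs (for example with `jax.block_until_ready`), or the timings will mostly measure dispatch.

```python
import statistics
import time
from typing import Callable, List


def time_steps(step_fn: Callable[[], object], num_steps: int) -> List[float]:
    """Wall-clock each call to `step_fn` (hypothetical helper), in seconds."""
    durations = []
    for _ in range(num_steps):
        start = time.perf_counter()
        # With JAX, dispatch is asynchronous: `step_fn` should block on its
        # outputs (e.g. via jax.block_until_ready) so the timing is real.
        step_fn()
        durations.append(time.perf_counter() - start)
    return durations


def slow_steps(durations: List[float], factor: float = 10.0) -> List[int]:
    """Indices of steps slower than `factor` times the median step.

    The 10x default is an arbitrary assumption; a 2 s step against a
    0.02 s median is ~100x, so it is flagged comfortably.
    """
    median = statistics.median(durations)
    return [i for i, d in enumerate(durations) if d > factor * median]
```

For instance, `slow_steps(time_steps(lambda: jax.block_until_ready(train_step()), 500))`, where `train_step` is a hypothetical closure around your own training step, returns the indices worth inspecting; a profiler like the one used above is still what attributes the time to `nnx.split`.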

Is there anything NNX-related that could be causing this to take a long time?


Labels: Priority: P1 - soon (response within 5 business days, resolution within 30 days; assignee required)
