Hi there,
We have been trying the kernel(s) in a longer training run and observe the following behavior after a few tens of thousands of steps (a minimal sketch of the three setups follows the list):
- when everything is in fp32 (NATTEN and the rest of the model, including normalization, etc.), everything is fine
- when NATTEN is in bf16 and everything else is in fp32, the system runs, but the metrics start diverging and getting worse
- when everything is in bf16, we run into NaN issues
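For reference, here is roughly how the three setups differ. This is a minimal, self-contained sketch, not our actual training code: the `Block` module, the `attn_bf16` flag, and the plain `nn.Linear` standing in for the real NATTEN attention layer are all illustrative.

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Stand-in block: `attn` marks where the NATTEN layer sits; a plain
    nn.Linear is used here so the sketch runs without natten installed."""

    def __init__(self, dim: int = 64, attn_bf16: bool = False):
        super().__init__()
        self.attn_bf16 = attn_bf16
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.Linear(dim, dim)   # <- NATTEN attention in our model
        self.mlp = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm(x)                  # norm stays in the outer dtype
        if self.attn_bf16:
            # Config 2: only the attention path runs in bf16. We train with
            # device_type="cuda"; "cpu" is used here so the sketch runs anywhere.
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                h = self.attn(h)
            h = h.float()
        else:
            h = self.attn(h)
        return x + self.mlp(h)

x = torch.randn(2, 16, 64)

out = Block()(x)                          # config 1: all fp32 -> stable
out = Block(attn_bf16=True)(x)            # config 2: metrics slowly diverge
blk = Block().to(torch.bfloat16)          # config 3: all bf16 -> NaNs
out = blk(x.to(torch.bfloat16))
```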
Have you observed anything like this before? Do you know whether there might be a bug in the backward pass or gradient accumulation/handling, or something in the dtype-specific versions/accumulation?
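In case it helps narrow things down, this is the kind of check we can run on our side to catch the first non-finite gradient in the backward pass. It is a generic PyTorch sketch; the function name is illustrative:

```python
import torch

def install_nan_grad_hooks(model: torch.nn.Module) -> None:
    """Register tensor hooks that raise on the first parameter whose
    gradient contains NaN/Inf, to localize where backward blows up."""
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        def check(grad: torch.Tensor, name: str = name) -> None:
            if not torch.isfinite(grad).all():
                raise RuntimeError(f"non-finite gradient in {name}")

        param.register_hook(check)

# usage: install_nan_grad_hooks(model); then loss.backward() raises
# at the first parameter that receives a NaN/Inf gradient.
```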
Thank you!