
NaN on longer training sequence #307

@coconutruben

Description


Hi there,

We were trying the kernel(s) on a longer training sequence and observed the following behavior after a few tens of thousands of steps (a rough sketch of each setup follows the list):

  • when everything is in fp32 (NATTEN and the rest of the model, including normalization, etc.), everything is fine
  • when NATTEN is in bf16 and everything else is in fp32, the model trains, but the metrics start diverging and getting worse
  • when everything is in bf16, we run into NaNs

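For context, this is roughly how the three setups differ. This is a made-up toy block rather than our actual model, and the NATTEN module name and arguments are from memory, so they may need adjusting for your version:

```python
import torch
import torch.nn as nn
from natten import NeighborhoodAttention2D  # import path/signature assumed; adjust to your NATTEN version

class Block(nn.Module):
    """Toy block: LayerNorm followed by 2D neighborhood attention with a residual."""
    def __init__(self, dim=128, heads=4, kernel_size=7):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = NeighborhoodAttention2D(dim=dim, num_heads=heads, kernel_size=kernel_size)

    def forward(self, x):
        return x + self.attn(self.norm(x))

x = torch.randn(2, 32, 32, 128, device="cuda")  # (B, H, W, C) channels-last input

# (1) everything in fp32 -- stable in our runs
block = Block().cuda()
y = block(x)

# (2) only NATTEN in bf16, normalization/residual kept in fp32 -- metrics slowly diverge
mixed = Block().cuda()
mixed.attn.to(torch.bfloat16)
h = mixed.norm(x)                         # fp32 LayerNorm
h = mixed.attn(h.to(torch.bfloat16))      # bf16 NATTEN kernels
y = x + h.float()                         # back to fp32 for the residual

# (3) everything in bf16 -- NaNs after a few tens of thousands of steps
block_bf16 = Block().cuda().to(torch.bfloat16)
y = block_bf16(x.to(torch.bfloat16))
```
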
Have you observed anything like this before? Do you know whether there might be a bug in the backward pass or gradient accumulation/handling, or something in the dtype variants and their accumulation?

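In case it helps narrow things down, this is the kind of minimal check we can add on our side to see where the first non-finite gradient shows up (plain PyTorch, nothing NATTEN-specific):

```python
import torch

def install_nan_checks(model: torch.nn.Module):
    """Report the first module whose incoming output-gradient contains NaN/Inf during backward()."""
    def hook(module, grad_input, grad_output):
        for g in grad_output:
            if g is not None and not torch.isfinite(g).all():
                print(f"non-finite grad flowing into {module.__class__.__name__}")
    for m in model.modules():
        m.register_full_backward_hook(hook)

# Alternatively, anomaly detection re-runs backward with extra checks (slow, debug only):
# with torch.autograd.set_detect_anomaly(True):
#     loss.backward()
```
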
Thank you!
