Why autocast in checkpoint_activation uses default dtype? #1199

@psolt-leonardo

Description

Hi FairScale team,

Thanks for your work on this project!

I was reading through checkpoint_activations.py, and I noticed that when autocast is enabled, recomputation is wrapped roughly like this:

if torch.is_autocast_enabled():
    with torch.cuda.amp.autocast():

This means it always uses the default autocast dtype (float16), even if the parent autocast context is using a different dtype like bfloat16 or float32.

I'm wondering: is there a reason for not explicitly matching the parent autocast dtype here (e.g., via torch.get_autocast_gpu_dtype())? As far as I understand, this could lead to:

  1. Underflow/overflow when the parent context uses bfloat16 but recomputation runs in float16, which has a much narrower dynamic range.
  2. Unwanted downcasting when the parent context uses float32 but recomputation switches to float16, introducing precision loss.

Would it make sense to replace this with:

with torch.cuda.amp.autocast(dtype=torch.get_autocast_gpu_dtype()):

to preserve consistency with the parent context?
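To make the failure mode concrete, here is a minimal pure-Python sketch of the pattern (the `autocast` context manager and `get_autocast_dtype` helper below are hypothetical stand-ins for PyTorch's thread-local autocast state, not the real torch API): re-entering the context without an explicit dtype silently falls back to the float16 default, while reading the parent dtype and passing it through preserves it.

```python
import contextlib
import threading

# Hypothetical stand-in for PyTorch's thread-local autocast state;
# names are illustrative, not the real torch API.
_state = threading.local()

def get_autocast_dtype():
    """Return the dtype of the innermost active autocast context, if any."""
    return getattr(_state, "dtype", None)

@contextlib.contextmanager
def autocast(dtype="float16"):
    # Like torch.cuda.amp.autocast(), the dtype defaults to float16
    # unless the caller passes one explicitly.
    prev = get_autocast_dtype()
    _state.dtype = dtype
    try:
        yield
    finally:
        _state.dtype = prev

# Parent context runs in bfloat16.
with autocast(dtype="bfloat16"):
    # Buggy pattern: re-entering autocast() with no argument
    # silently reverts to the float16 default.
    with autocast():
        inner_default = get_autocast_dtype()
    # Proposed pattern: read the parent dtype and pass it through.
    with autocast(dtype=get_autocast_dtype()):
        inner_matched = get_autocast_dtype()

print(inner_default, inner_matched)  # float16 bfloat16
```

The recomputation in checkpoint_activations.py hits the first branch today; the proposal is the second branch, with torch.get_autocast_gpu_dtype() playing the role of the helper.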

Curious to understand the reasoning here — happy to open a PR if this is something that needs fixing.

Thanks!
