Hi FairScale team,
Thanks for your work on this project!
I was reading through checkpoint_activations.py, and I noticed that when autocast is enabled, recomputation is wrapped roughly like this:
```python
if torch.is_autocast_enabled():
    with torch.cuda.amp.autocast():
```
This means it always uses the default autocast dtype (float16), even if the parent autocast context is using a different dtype like bfloat16 or float32.
I'm wondering: is there a reason for not explicitly matching the parent autocast dtype here (e.g., via torch.get_autocast_gpu_dtype())? I'm asking because in my understanding this could potentially lead to:
- Underflow/overflow when the parent context is using `bfloat16`, but recomputation uses `float16` (due to its narrower dynamic range).
- Unwanted downcasting when the parent context is using `float32`, but recomputation switches to `float16`, possibly introducing precision loss.
Would it make sense to replace this with:
```python
with torch.cuda.amp.autocast(dtype=torch.get_autocast_gpu_dtype()):
```
to preserve consistency with the parent context?
Curious to understand the reasoning here — happy to open a PR if this is something that needs fixing.
Thanks!