Add a stub for host offloading docs (pytorch#8656)
tengyifei authored Jan 30, 2025
1 parent 93a2ba6 commit e583c2c
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions in examples/host_offloading/README.md
@@ -1,3 +1,19 @@
This directory will contain a self-contained example for host offloading
by the time of the 2.6 release.
## Host offloading example

When doing reverse-mode automatic differentiation, many tensors are saved
during the forward pass so they can be used to compute gradients during the backward pass.
Previously you could use `torch_xla.utils.checkpoint` to discard tensors that are easy
to recompute later, a technique called "checkpointing" or "rematerialization". Now PyTorch/XLA
also supports a technique called "host offloading", i.e. moving a tensor to the host
and moving it back later, adding another tool to the arsenal for saving memory. Use
`torch_xla.experimental.stablehlo_custom_call.place_to_host` to move a tensor to the host
and `torch_xla.experimental.stablehlo_custom_call.place_to_device` to move a tensor
back to the device. For example, you can use this to move intermediate activations
to the host during the forward pass, and move those activations back to the device during
the corresponding backward pass.
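
Below is a minimal sketch of that pattern (not the upcoming official example): a
hypothetical `OffloadedMatmul` autograd function that offloads its saved activation
to the host in the forward pass and brings it back in the backward pass. It assumes
`place_to_host` and `place_to_device` each take a single tensor and return a tensor.

```python
import torch
from torch_xla.experimental.stablehlo_custom_call import (
    place_to_host,
    place_to_device,
)


class OffloadedMatmul(torch.autograd.Function):
    """Matrix multiply whose saved input activation is kept on the host."""

    @staticmethod
    def forward(ctx, x, w):
        out = x @ w
        # Annotate the saved activation so the compiler places it in host
        # memory between the forward and backward passes.
        ctx.save_for_backward(place_to_host(x), w)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        x_host, w = ctx.saved_tensors
        # Bring the activation back to the device for the gradient computation.
        x = place_to_device(x_host)
        grad_x = grad_out @ w.t()
        grad_w = x.t() @ grad_out
        return grad_x, grad_w


# Hypothetical usage on an XLA device:
#   import torch_xla.core.xla_model as xm
#   device = xm.xla_device()
#   x = torch.randn(128, 256, device=device, requires_grad=True)
#   w = torch.randn(256, 512, device=device, requires_grad=True)
#   OffloadedMatmul.apply(x, w).sum().backward()
```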

Because the XLA graph compiler aggressively reorders operations, host offloading is
best used in combination with `scan`.

TODO(yifeit): Clean up the example in https://github.com/tengyifei/playground/blob/master/graph_transforms/offloading.py
and put that here.
