-
Notifications
You must be signed in to change notification settings - Fork 534
Add cache to value_and_grad_partitioned
#9163
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -51,17 +51,26 @@ | |
from torch_xla.distributed.spmd.xla_sharding import shard_as | ||
import torch_xla.debug.profiler as xp | ||
import torch_xla.runtime | ||
from weakref import WeakKeyDictionary | ||
|
||
Carry = TypeVar('Carry') | ||
X = TypeVar('X') | ||
Y = TypeVar('Y') | ||
|
||
# A cache of the forward, alias_input, backward for `scan`. It has a sturcture | ||
# of {fn_ref: {input_key: (forward, alias_input, backward)}}. | ||
# The `fn_ref` is the address of the given function and is weakly referenced. | ||
# The input_key is computed using the shapes, dtypes, and the pytree specs of | ||
# the carry and xs. | ||
_SCAN_COMPUTATION_CACHE = WeakKeyDictionary() | ||
|
||
|
||
def scan( | ||
fn: Callable[[Carry, X], tuple[Carry, Y]], | ||
init: Carry, | ||
xs: X, | ||
partition_fn=default_partition, | ||
is_fn_pure: bool = True, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since fn is not necessarily pure, this optimization should be opt-in. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I set it to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's use If the
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated the default to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Still There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry for the last minute nit: could we |
||
# TODO: consider exposing knobs to control the RNG seed used in each `fn` iteration. | ||
) -> tuple[Carry, Y]: | ||
"""Apply a function over leading dimension of tensors while carrying along state. | ||
|
@@ -110,6 +119,11 @@ def scan(fn, init, xs): | |
based activation checkpointing. You may also write your own partitioner to insert any | ||
custom logic such as host offloading of activations. | ||
|
||
is_fn_pure: (Optional[bool]) If `fn` is pure, the tracing cache will be enabled. A pure | ||
function always produces the same output for the same input, and it doesn't have any | ||
side effects, meaning it doesn't modify any state outside of itself. Essentially, it's | ||
like a mathematical function that only depends on its input arguments. | ||
|
||
Returns: | ||
(carry, ys): A tuple where `carry` is the last carry object returned by `fn`, and | ||
`ys` is a PyTree with the same structure as `xs`, but where the leaves are formed | ||
|
@@ -160,7 +174,7 @@ def scan(fn, init, xs): | |
raise ValueError(f"`xs` {xs} is an empty PyTree.") | ||
|
||
forward, alias_input, backward = value_and_grad_partitioned( | ||
fn, init, xs, partition_fn=partition_fn) | ||
fn, init, xs, partition_fn=partition_fn, is_fn_pure=is_fn_pure) | ||
carry, ys = Scan.apply(forward, alias_input, backward, init, | ||
xs) # type: ignore | ||
return carry, ys | ||
|
@@ -170,7 +184,8 @@ def value_and_grad_partitioned( | |
fn: Callable[[Carry, X], tuple[Carry, Y]], | ||
init: Carry, | ||
xs: X, | ||
partition_fn=default_partition) -> tuple[Callable, Callable, Callable]: | ||
partition_fn=default_partition, | ||
is_fn_pure=True) -> tuple[Callable, Callable, Callable]: | ||
""" | ||
Given a user `fn` to be scanned over the leading dimension of the input `xs` | ||
PyTree and an initial carry object `init`, symbolically traces `fn` and | ||
|
@@ -213,11 +228,23 @@ def value_and_grad_partitioned( | |
|
||
partition_fn: An optional partitioning function used to partition fn into | ||
forward and backward graphs. | ||
|
||
is_fn_pure: (Optional[bool]) If `fn` is pure, the tracing cache will be enabled. | ||
|
||
Returns: | ||
A tuple of `(forward, alias_input, backward)`, detailed in the docstring of this function. | ||
""" | ||
|
||
# compute the second-level cache key for tracing and generating the forward and backward graphs. | ||
# The key is a tuple of partition_fn's id, the shapes, dtypes, and the pytree specs of the carry and xs. | ||
def compute_second_level_cache_key(carry_pytree, x_pytree): | ||
carry_flat, carry_flat_spec = tree_flatten(carry_pytree) | ||
x_flat, x_flat_spec = tree_flatten(x_pytree) | ||
carry_key = tuple( | ||
(tuple(tensor.shape), tensor.dtype) for tensor in carry_flat) | ||
x_key = tuple((tuple(tensor.shape), tensor.dtype) for tensor in x_flat) | ||
return (id(partition_fn), carry_key, x_key, carry_flat_spec, x_flat_spec) | ||
|
||
# Make some fake tensors to trace the user function and obtain the | ||
# forward and backward graphs. Note that the init/carry fake tensor | ||
# always requires grad. That's because even if the user passed in some | ||
|
@@ -233,6 +260,12 @@ def make_fake_tensor(v: torch.Tensor, requires_grad=True) -> torch.Tensor: | |
fake_x_pytree = tree_map( | ||
lambda v: make_fake_tensor(v[0], requires_grad=v.requires_grad), xs) | ||
|
||
second_level_cache_key = compute_second_level_cache_key( | ||
fake_carry_pytree, fake_x_pytree) | ||
if is_fn_pure and fn in _SCAN_COMPUTATION_CACHE: | ||
if second_level_cache_key in _SCAN_COMPUTATION_CACHE[fn]: | ||
return _SCAN_COMPUTATION_CACHE[fn][second_level_cache_key] | ||
|
||
# If an output of `fn` aliases the input, `aot_function` will handle that | ||
# pair of variables with an epilogue inside its generated autograd.Function | ||
# that we can't access. In other words, the captured graph won't contain | ||
|
@@ -328,6 +361,13 @@ def backward(carry, x): | |
grad_carry, grad_x = unflatten_bwd_out(out) | ||
return grad_carry, grad_x | ||
|
||
# Cache the forward and backward graphs for later use. | ||
if is_fn_pure: | ||
if fn not in _SCAN_COMPUTATION_CACHE: | ||
_SCAN_COMPUTATION_CACHE[fn] = {} | ||
_SCAN_COMPUTATION_CACHE[fn][second_level_cache_key] = (forward, alias_input, | ||
backward) | ||
|
||
return forward, alias_input, backward | ||
|
||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.