93 changes: 93 additions & 0 deletions test/scan/test_scan.py
@@ -10,6 +10,7 @@

import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.experimental.scan as scan_module
from torch_xla.experimental.scan import scan, value_and_grad_partitioned, tree_flatten_none

parent_folder = os.path.dirname(os.path.dirname(__file__))
@@ -486,6 +487,98 @@ def unpack(x):
# as opposed to just numerically identical but otherwise an extra copy.
assert id(stored_xs) == id(xs)

def test_scan_computation_cache(self):
"""
Test that the computation cache is populated correctly.
"""

def fn1(carry, x):
return carry + x, x

init = torch.tensor([0.0, 0.0], device=self.device)
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
device=self.device,
requires_grad=True)
scan(fn1, init, xs)

cache = scan_module._SCAN_COMPUTATION_CACHE

# Check that fn1 is in the cache
assert fn1 in cache, "fn1 should be in the cache"

# Inspect the second-level cache for fn1
second_level_cache = cache[fn1]
assert len(second_level_cache) > 0, "Second-level cache should not be empty"

# Each cached entry holds the traced forward, alias_input, and backward graphs.
for key, value in second_level_cache.items():
forward, alias_input, backward = value
# All three cached graphs should be callable.
assert callable(forward)
assert callable(alias_input)
assert callable(backward)

def test_scan_computation_cache_by_fn_and_partition_fn(self):
"""
Test that the computation cache is populated by fn and partition_fn.
"""

def fn1(carry, x):
return carry + x, x

def fn2(carry, x):
return carry * x, x

init = torch.tensor([0.0, 0.0], device=self.device)
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
device=self.device,
requires_grad=True)
scan(fn1, init, xs)
scan(fn2, init, xs)

cache = scan_module._SCAN_COMPUTATION_CACHE

# Check that both fn1 and fn2 are in the cache
assert fn1 in cache, "fn1 should be in the cache"
assert fn2 in cache, "fn2 should be in the cache"

# Inspect the second-level cache for fn1
second_level_cache = cache[fn1]
assert len(
    second_level_cache) == 1, "Second-level cache should have exactly 1 entry"

# Inspect the second-level cache for fn2
second_level_cache = cache[fn2]
assert len(
    second_level_cache) == 1, "Second-level cache should have exactly 1 entry"

# Check if the partition function created a new cache entry
scan(fn1, init, xs, partition_fn=min_cut_rematerialization_partition)
second_level_cache = cache[fn1]
# A different partition_fn should create a second entry for fn1
assert len(second_level_cache
           ) == 2, "Second-level cache should have exactly 2 entries. Got: " + str(
               len(second_level_cache))

def test_scan_computation_cache_disabled_when_fn_is_not_pure(self):
"""
Test that the computation cache is not populated when the function is not pure.
"""

def fn1(carry, x):
return carry + x, x

init = torch.tensor([0.0, 0.0], device=self.device)
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
device=self.device,
requires_grad=True)
scan(fn1, init, xs, is_fn_pure=False)

cache = scan_module._SCAN_COMPUTATION_CACHE

# fn1 should not be cached because it was declared impure (is_fn_pure=False)
assert fn1 not in cache, "fn1 should not be in the cache"
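
  # A minimal illustrative sketch (an assumption, not part of this change): the
  # second-level cache key also encodes the input shapes and dtypes, so scanning
  # the same pure `fn1` over differently shaped inputs should add a second entry
  # for `fn1` instead of reusing the first one.
  def test_scan_computation_cache_by_input_shape(self):

    def fn1(carry, x):
      return carry + x, x

    init_a = torch.zeros(2, device=self.device)
    xs_a = torch.zeros(3, 2, device=self.device, requires_grad=True)
    init_b = torch.zeros(4, device=self.device)
    xs_b = torch.zeros(3, 4, device=self.device, requires_grad=True)
    scan(fn1, init_a, xs_a)
    scan(fn1, init_b, xs_b)

    cache = scan_module._SCAN_COMPUTATION_CACHE
    assert len(cache[fn1]) == 2, "Each input signature should have its own entry"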


class PyTreeTest(TestBase):

44 changes: 42 additions & 2 deletions torch_xla/experimental/scan.py
@@ -51,17 +51,26 @@
from torch_xla.distributed.spmd.xla_sharding import shard_as
import torch_xla.debug.profiler as xp
import torch_xla.runtime
from weakref import WeakKeyDictionary

Carry = TypeVar('Carry')
X = TypeVar('X')
Y = TypeVar('Y')

# A cache of the (forward, alias_input, backward) computations for `scan`. It
# has a structure of {fn: {input_key: (forward, alias_input, backward)}}.
# The outer key is the user function itself, held as a weak reference so the
# entry is dropped once the function is garbage collected. The input_key is
# computed from the partition function's id and the shapes, dtypes, and pytree
# specs of the carry and xs.
_SCAN_COMPUTATION_CACHE = WeakKeyDictionary()
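
# Illustrative sketch (an assumption, not part of this change) of what the
# cache holds after a pure `fn` has been traced once with the default
# partitioner:
#   _SCAN_COMPUTATION_CACHE[fn] == {
#       (id(default_partition), carry_key, x_key, carry_spec, x_spec):
#           (forward, alias_input, backward),
#   }
# Because the outer mapping is a WeakKeyDictionary keyed by `fn` itself, the
# whole entry disappears once the user function is garbage collected.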


def scan(
fn: Callable[[Carry, X], tuple[Carry, Y]],
init: Carry,
xs: X,
partition_fn=default_partition,
is_fn_pure: bool = True,
# TODO: consider exposing knobs to control the RNG seed used in each `fn` iteration.
) -> tuple[Carry, Y]:
"""Apply a function over leading dimension of tensors while carrying along state.
@@ -110,6 +119,11 @@ def scan(fn, init, xs):
based activation checkpointing. You may also write your own partitioner to insert any
custom logic such as host offloading of activations.

is_fn_pure: (Optional[bool]) If `fn` is pure, the tracing cache is enabled so
repeated `scan` calls with the same function and input signature reuse the
traced graphs. A pure function always produces the same output for the same
input and has no side effects, i.e. it does not read or modify state outside
of itself.
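For example, `lambda carry, x: (carry + x, x)` is pure, while a function that
increments a global counter is not; pass `is_fn_pure=False` in that case so
cached graphs are not reused across calls.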

Returns:
(carry, ys): A tuple where `carry` is the last carry object returned by `fn`, and
`ys` is a PyTree with the same structure as `xs`, but where the leaves are formed
@@ -160,7 +174,7 @@ def scan(fn, init, xs):
raise ValueError(f"`xs` {xs} is an empty PyTree.")

forward, alias_input, backward = value_and_grad_partitioned(
fn, init, xs, partition_fn=partition_fn)
fn, init, xs, partition_fn=partition_fn, is_fn_pure=is_fn_pure)
carry, ys = Scan.apply(forward, alias_input, backward, init,
xs) # type: ignore
return carry, ys
@@ -170,7 +184,8 @@ def value_and_grad_partitioned(
fn: Callable[[Carry, X], tuple[Carry, Y]],
init: Carry,
xs: X,
partition_fn=default_partition) -> tuple[Callable, Callable, Callable]:
partition_fn=default_partition,
is_fn_pure=True) -> tuple[Callable, Callable, Callable]:
"""
Given a user `fn` to be scanned over the leading dimension of the input `xs`
PyTree and an initial carry object `init`, symbolically traces `fn` and
@@ -213,11 +228,23 @@ def value_and_grad_partitioned(

partition_fn: An optional partitioning function used to partition fn into
forward and backward graphs.

is_fn_pure: (Optional[bool]) If `fn` is pure, the tracing cache will be enabled.

Returns:
A tuple of `(forward, alias_input, backward)`, detailed in the docstring of this function.
"""

# Compute the second-level cache key used to look up previously traced forward
# and backward graphs. The key is a tuple of the partition_fn's id and the
# shapes, dtypes, and pytree specs of the carry and xs.
def compute_second_level_cache_key(carry_pytree, x_pytree):
carry_flat, carry_flat_spec = tree_flatten(carry_pytree)
x_flat, x_flat_spec = tree_flatten(x_pytree)
carry_key = tuple(
(tuple(tensor.shape), tensor.dtype) for tensor in carry_flat)
x_key = tuple((tuple(tensor.shape), tensor.dtype) for tensor in x_flat)
return (id(partition_fn), carry_key, x_key, carry_flat_spec, x_flat_spec)
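  # As an illustration (assumed concrete values, not from this change): for a
  # float32 carry of shape (2,) and per-step slices of shape (2,), the key is
  # roughly
  #   (id(partition_fn), (((2,), torch.float32),), (((2,), torch.float32),),
  #    carry_spec, x_spec)
  # so two calls only share an entry when the partitioner, shapes, dtypes, and
  # pytree structures all match.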

# Make some fake tensors to trace the user function and obtain the
# forward and backward graphs. Note that the init/carry fake tensor
# always requires grad. That's because even if the user passed in some
@@ -233,6 +260,12 @@ def make_fake_tensor(v: torch.Tensor, requires_grad=True) -> torch.Tensor:
fake_x_pytree = tree_map(
lambda v: make_fake_tensor(v[0], requires_grad=v.requires_grad), xs)

second_level_cache_key = compute_second_level_cache_key(
fake_carry_pytree, fake_x_pytree)
if is_fn_pure and fn in _SCAN_COMPUTATION_CACHE:
if second_level_cache_key in _SCAN_COMPUTATION_CACHE[fn]:
return _SCAN_COMPUTATION_CACHE[fn][second_level_cache_key]

# If an output of `fn` aliases the input, `aot_function` will handle that
# pair of variables with an epilogue inside its generated autograd.Function
# that we can't access. In other words, the captured graph won't contain
@@ -328,6 +361,13 @@ def backward(carry, x):
grad_carry, grad_x = unflatten_bwd_out(out)
return grad_carry, grad_x

# Cache the forward, alias_input, and backward graphs for reuse on later calls.
if is_fn_pure:
if fn not in _SCAN_COMPUTATION_CACHE:
_SCAN_COMPUTATION_CACHE[fn] = {}
_SCAN_COMPUTATION_CACHE[fn][second_level_cache_key] = (forward, alias_input,
backward)

return forward, alias_input, backward

