Add cache to value_and_grad_partitioned #9163

Open · wants to merge 3 commits into master

152 changes: 135 additions & 17 deletions test/scan/test_scan.py
@@ -10,11 +10,13 @@

import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.experimental.scan as scan_module
from torch_xla.experimental.scan import scan, value_and_grad_partitioned, tree_flatten_none

parent_folder = os.path.dirname(os.path.dirname(__file__))
sys.path.append(parent_folder)
from test_utils import XlaTestCase # type:ignore
from absl.testing import parameterized


def _loopy_scan(fn, init, xs):
@@ -44,6 +46,8 @@ class TestBase(XlaTestCase):
def setUp(self):
super().setUp()
self.device = torch_xla.device()
# Clear the scan computation cache before each test to avoid cross-test contamination.
scan_module._SCAN_COMPUTATION_CACHE.clear()

def compare_pytree(self, expected_pytree, actual_pytree):
flat_expected_pytree, expected_spec = tree_flatten(expected_pytree)
@@ -59,13 +63,14 @@ def compare_pytree(self, expected_pytree, actual_pytree):
super().compareResults(flat_expected_pytree, flat_actual_pytree)


class ScanTest(TestBase):
class ScanTest(TestBase, parameterized.TestCase):

def run_test(self,
fn,
init: PyTree,
xs: PyTree,
partition_fn=default_partition):
partition_fn=default_partition,
is_fn_pure: bool = True):
"""Compares the result of scanning with `fn` with our optimized HLO implementation
against a for loop implementation. Checks both output values and gradients.
"""
@@ -78,7 +83,12 @@ def run_test(self,
# Actual output
init_scan = tree_map(dupe, init)
xs_scan = tree_map(dupe, xs)
final_carry, ys = scan(fn, init_scan, xs_scan, partition_fn=partition_fn)
final_carry, ys = scan(
fn,
init_scan,
xs_scan,
partition_fn=partition_fn,
is_fn_pure=is_fn_pure)
# Add up all leaves and `backward()` once.
(squish(final_carry) + squish(ys)).backward()
torch_xla.sync()
@@ -105,7 +115,8 @@ def run_test(self,

return final_carry, ys

def test_scan_simple(self):
@parameterized.parameters(True, False)
def test_scan_simple(self, is_fn_pure: bool):
"""This test uses `scan` to implement `torch.cumsum`."""

def step_fn(carry, x):
@@ -117,7 +128,7 @@ def step_fn(carry, x):
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
requires_grad=True,
device=self.device)
final_carry, ys = self.run_test(step_fn, init, xs)
final_carry, ys = self.run_test(step_fn, init, xs, is_fn_pure=is_fn_pure)

# Also ensure that our loop-based scan is correct, with manual checks
# that replicate the step_fn.
@@ -140,7 +151,8 @@ def test_scan_incompatible_length(self):
with self.assertRaises(ValueError):
scan(lambda a, b: (a, b), init, (xs_1, xs_2))

def test_scan_tuples(self):
@parameterized.parameters(True, False)
def test_scan_tuples(self, is_fn_pure: bool):
"""Test scanning over the leading axis of a tuple of tensors simultaneously,
which is a simple PyTree."""

@@ -163,9 +175,10 @@ def fn(carry, x):
requires_grad=True,
device=self.device))

self.run_test(fn, init, xs)
self.run_test(fn, init, xs, is_fn_pure=is_fn_pure)

def test_scan_create_tensors(self):
@parameterized.parameters(True, False)
def test_scan_create_tensors(self, is_fn_pure: bool):
"""Test scanning over a function that internally creates tensors."""

def fn(carry, x):
@@ -177,7 +190,7 @@ def fn(carry, x):
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
requires_grad=True,
device=self.device)
self.run_test(fn, init, xs)
self.run_test(fn, init, xs, is_fn_pure=is_fn_pure)

def test_scan_create_tensors_no_transfers_from_device(self):
"""Test that scanning over a function that internally creates tensors
@@ -220,7 +233,8 @@ def fn(carry, x):
device=self.device)
self.run_test(fn, init, xs)

def test_scan_input_output_aliases_carry(self):
@parameterized.parameters(True, False)
def test_scan_input_output_aliases_carry(self, is_fn_pure: bool):
"""
Test scan still works when a fn output aliases its carry input.
"""
@@ -232,9 +246,10 @@ def fn(carry, x):
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
requires_grad=True,
device=self.device)
self.run_test(fn, init, xs)
self.run_test(fn, init, xs, is_fn_pure=is_fn_pure)

def test_scan_input_output_aliases_x(self):
@parameterized.parameters(True, False)
def test_scan_input_output_aliases_x(self, is_fn_pure: bool):
"""
Test scan still works when a fn output aliases its x input.
"""
@@ -246,7 +261,7 @@ def fn(carry, x):
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
requires_grad=True,
device=self.device)
self.run_test(fn, init, xs)
self.run_test(fn, init, xs, is_fn_pure=is_fn_pure)

def test_scan_input_in_place_mutation(self):
"""
@@ -288,7 +303,8 @@ def step_fn(carry, x):
with self.assertRaisesRegex(AssertionError, "FakeTensor"):
scan(step_fn, init, xs)

def test_scan_gradness(self):
@parameterized.parameters(True, False)
def test_scan_gradness(self, is_fn_pure: bool):
"""
Test the gradient output of `scan` when various inputs require or doesn't
require gradients.
@@ -307,7 +323,7 @@ def fn(carry, x):
xs = torch.tensor([[2.0, 3.0], [4.0, 5.0], [6.0, 7.0]],
requires_grad=xs_requires_grad,
device=self.device)
self.run_test(fn, init, xs)
self.run_test(fn, init, xs, is_fn_pure=is_fn_pure)

test_case(True, True)
test_case(True, False)
@@ -445,7 +461,8 @@ def fn(carry, x):
self.assertEqual(bf16_ys.dtype, torch.bfloat16)
self.assertEqual(f32_ys.dtype, torch.float32)

def test_scan_activation_aliases_input(self):
@parameterized.parameters(True, False)
def test_scan_activation_aliases_input(self, is_fn_pure: bool):
"""Test that if an intermediate activation of fn aliases an input,
we directly save the input tensor into the context object, instead of
indexing into the leading dimension during the while loop and copying
@@ -470,7 +487,7 @@ def unpack(x):

# Intercept the tensors stored in the context object.
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
final_carry, ys = scan(fn, carry, xs)
final_carry, ys = scan(fn, carry, xs, is_fn_pure=is_fn_pure)
ys.sum().backward()
torch_xla.sync()

@@ -487,6 +504,107 @@ def unpack(x):
# as opposed to just numerically identical but otherwise an extra copy.
assert id(stored_xs) == id(xs)

def test_scan_computation_cache(self):
"""
Test that the computation cache is populated correctly.
"""
fn1_call_count = 0

def fn1(carry, x):
nonlocal fn1_call_count
fn1_call_count += 1
return carry + x, x

init = torch.tensor([0.0, 0.0], device=self.device)
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
device=self.device,
requires_grad=True)

for _ in range(10):
scan(fn1, init, xs)

cache = scan_module._SCAN_COMPUTATION_CACHE

# Check that fn1 is in the cache.
assert fn1 in cache, "fn1 should be in the cache"

# Inspect the second-level cache for fn1.
second_level_cache = cache[fn1]
assert len(second_level_cache) > 0, "Second-level cache should not be empty"

# fn1 should be traced exactly twice: once to construct the forward graph and
# once to construct the backward graph. The cache prevents any further tracing
# across the repeated scan calls above.
assert fn1_call_count == 2, \
"fn1 should be called exactly twice (once for the forward graph and once for the backward graph), but was called " + str(fn1_call_count)

# Each cache entry should hold callable forward, alias_input, and backward functions.
for key, value in second_level_cache.items():
forward, alias_input, backward = value
assert callable(forward)
assert callable(alias_input)
assert callable(backward)

def test_scan_computation_cache_by_fn_and_partition_fn(self):
"""
Test that the computation cache is keyed by both fn and partition_fn.
"""

def fn1(carry, x):
return carry + x, x

def fn2(carry, x):
return carry * x, x

init = torch.tensor([0.0, 0.0], device=self.device)
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
device=self.device,
requires_grad=True)
scan(fn1, init, xs)
scan(fn2, init, xs)

cache = scan_module._SCAN_COMPUTATION_CACHE

# Check that both fn1 and fn2 are in the cache.
assert fn1 in cache, "fn1 should be in the cache"
assert fn2 in cache, "fn2 should be in the cache"

# Inspect the second-level cache for fn1.
second_level_cache = cache[fn1]
assert len(
second_level_cache) == 1, "Second-level cache should have exactly one entry"

# Inspect the second-level cache for fn2.
second_level_cache = cache[fn2]
assert len(
second_level_cache) == 1, "Second-level cache should have exactly one entry"

# Using a different partition_fn should create a second cache entry for fn1.
scan(fn1, init, xs, partition_fn=min_cut_rematerialization_partition)
second_level_cache = cache[fn1]
assert len(second_level_cache) == 2, \
"Second-level cache should have exactly two entries. Got: " + str(len(second_level_cache))

def test_scan_computation_cache_disabled_when_fn_is_not_pure(self):
"""
Test that the computation cache is not populated when the function is not pure.
"""

def fn1(carry, x):
return carry + x, x

init = torch.tensor([0.0, 0.0], device=self.device)
xs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
device=self.device,
requires_grad=True)
scan(fn1, init, xs, is_fn_pure=False)

cache = scan_module._SCAN_COMPUTATION_CACHE

# fn1 should not be cached because it was declared impure.
assert fn1 not in cache, "fn1 should not be in the cache"


class PyTreeTest(TestBase):

44 changes: 42 additions & 2 deletions torch_xla/experimental/scan.py
@@ -51,17 +51,26 @@
from torch_xla.distributed.spmd.xla_sharding import shard_as
import torch_xla.debug.profiler as xp
import torch_xla.runtime
from weakref import WeakKeyDictionary

Carry = TypeVar('Carry')
X = TypeVar('X')
Y = TypeVar('Y')

# A cache of the (forward, alias_input, backward) computations produced for `scan`.
# It has the structure {fn_ref: {input_key: (forward, alias_input, backward)}}.
# `fn_ref` is a weak reference to the given function. `input_key` is computed
# from the id of the partition function plus the shapes, dtypes, and pytree
# specs of the carry and xs.
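# Illustrative shape of a populated cache (names and values are hypothetical):
#   {<weakly referenced step_fn>:
#       {(id(partition_fn), carry_key, x_key, carry_spec, x_spec):
#           (forward, alias_input, backward)}}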
_SCAN_COMPUTATION_CACHE = WeakKeyDictionary()


def scan(
fn: Callable[[Carry, X], tuple[Carry, Y]],
init: Carry,
xs: X,
partition_fn=default_partition,
is_fn_pure: bool = True,
Review thread on the `is_fn_pure` parameter:

Collaborator: Since `fn` is not necessarily pure, this optimization should be opt-in.

Author: I set it to True for two reasons: 1. in my experience, most `fn`s are pure; 2. if it is disabled by default, not many people would use the cache. What do you think?

@tengyifei (Collaborator, Jun 4, 2025): Let's use False for safety. Even if most people don't use it at the start, this will prove very valuable in our internal usages. For example, we could make GRU use scan(is_fn_pure=True), which would get rid of the significant tracing overhead in GRU.

If scan defaults to assuming the function is pure, then code like this will fail in mysterious ways:

foo = False
def combine_fn():
  if foo:
    bar()
  else:
    baz()

Author: Updated the default to False.

Collaborator: Still True as of 7313adc.

Collaborator: Sorry for the last-minute nit: could we s/is_fn_pure/fn_is_pure? fn_is_pure reads more naturally!

# TODO: consider exposing knobs to control the RNG seed used in each `fn` iteration.
) -> tuple[Carry, Y]:
"""Apply a function over leading dimension of tensors while carrying along state.
@@ -110,6 +119,11 @@ def scan(fn, init, xs):
based activation checkpointing. You may also write your own partitioner to insert any
custom logic such as host offloading of activations.

is_fn_pure: (Optional[bool]) Set this to True when `fn` is pure to enable the tracing
cache. A pure function always produces the same output for the same input and has
no side effects, i.e. it does not read or modify any state outside of itself;
essentially, it behaves like a mathematical function of its input arguments alone.
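For example, `lambda carry, x: (carry + x, x)` is pure, while a function that reads
or mutates a global or captured Python variable between calls is not.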

Returns:
(carry, ys): A tuple where `carry` is the last carry object returned by `fn`, and
`ys` is a PyTree with the same structure as `xs`, but where the leaves are formed
@@ -160,7 +174,7 @@ def scan(fn, init, xs):
raise ValueError(f"`xs` {xs} is an empty PyTree.")

forward, alias_input, backward = value_and_grad_partitioned(
fn, init, xs, partition_fn=partition_fn)
fn, init, xs, partition_fn=partition_fn, is_fn_pure=is_fn_pure)
carry, ys = Scan.apply(forward, alias_input, backward, init,
xs) # type: ignore
return carry, ys
@@ -170,7 +184,8 @@ def value_and_grad_partitioned(
fn: Callable[[Carry, X], tuple[Carry, Y]],
init: Carry,
xs: X,
partition_fn=default_partition) -> tuple[Callable, Callable, Callable]:
partition_fn=default_partition,
is_fn_pure=True) -> tuple[Callable, Callable, Callable]:
"""
Given a user `fn` to be scanned over the leading dimension of the input `xs`
PyTree and an initial carry object `init`, symbolically traces `fn` and
@@ -213,11 +228,23 @@ def value_and_grad_partitioned(

partition_fn: An optional partitioning function used to partition fn into
forward and backward graphs.

is_fn_pure: (Optional[bool]) Set this to True when `fn` is pure to enable the tracing cache.

Returns:
A tuple of `(forward, alias_input, backward)`, detailed in the docstring of this function.
"""

# Compute the second-level cache key used to look up previously traced forward
# and backward graphs. The key is a tuple of the partition_fn's id and the
# shapes, dtypes, and pytree specs of the carry and xs.
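# For instance (illustrative values only), a single float32 carry tensor of shape
# (2,) and xs slices of shape (2,) yield a key like
# (id(partition_fn), (((2,), torch.float32),), (((2,), torch.float32),), carry_spec, x_spec).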
def compute_second_level_cache_key(carry_pytree, x_pytree):
carry_flat, carry_flat_spec = tree_flatten(carry_pytree)
x_flat, x_flat_spec = tree_flatten(x_pytree)
carry_key = tuple(
(tuple(tensor.shape), tensor.dtype) for tensor in carry_flat)
x_key = tuple((tuple(tensor.shape), tensor.dtype) for tensor in x_flat)
return (id(partition_fn), carry_key, x_key, carry_flat_spec, x_flat_spec)

# Make some fake tensors to trace the user function and obtain the
# forward and backward graphs. Note that the init/carry fake tensor
# always requires grad. That's because even if the user passed in some
@@ -233,6 +260,12 @@ def make_fake_tensor(v: torch.Tensor, requires_grad=True) -> torch.Tensor:
fake_x_pytree = tree_map(
lambda v: make_fake_tensor(v[0], requires_grad=v.requires_grad), xs)

second_level_cache_key = compute_second_level_cache_key(
fake_carry_pytree, fake_x_pytree)
if is_fn_pure and fn in _SCAN_COMPUTATION_CACHE:
if second_level_cache_key in _SCAN_COMPUTATION_CACHE[fn]:
return _SCAN_COMPUTATION_CACHE[fn][second_level_cache_key]

# If an output of `fn` aliases the input, `aot_function` will handle that
# pair of variables with an epilogue inside its generated autograd.Function
# that we can't access. In other words, the captured graph won't contain
@@ -328,6 +361,13 @@ def backward(carry, x):
grad_carry, grad_x = unflatten_bwd_out(out)
return grad_carry, grad_x

# Cache the forward, alias_input, and backward functions for later reuse.
if is_fn_pure:
if fn not in _SCAN_COMPUTATION_CACHE:
_SCAN_COMPUTATION_CACHE[fn] = {}
_SCAN_COMPUTATION_CACHE[fn][second_level_cache_key] = (forward, alias_input,
backward)

return forward, alias_input, backward


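A minimal usage sketch of the cached path (not part of this PR's diff; the step function and shapes below are made up for illustration):

import torch
import torch_xla
from torch_xla.experimental.scan import scan

def step(carry, x):
  # Pure: the outputs depend only on the inputs and there are no side effects.
  return carry + x, x

device = torch_xla.device()
init = torch.zeros(2, device=device)
xs = torch.arange(6.0, device=device).reshape(3, 2)

# The first call traces `step` and stores the partitioned forward/backward graphs
# in _SCAN_COMPUTATION_CACHE, keyed by `step` plus the shapes, dtypes, and pytree
# specs of the inputs and the id of the partition function.
final_carry, ys = scan(step, init, xs, is_fn_pure=True)

# Subsequent calls with the same shapes, dtypes, and specs reuse the cached
# graphs instead of re-tracing `step`.
final_carry, ys = scan(step, init, xs, is_fn_pure=True)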