
Commit 225c65b

scan documentation (#8631)
Co-authored-by: Michael Green <[email protected]>
1 parent 8e6ca60 commit 225c65b

9 files changed: +403 −15 lines

docs/source/features/scan.md

+202
@@ -0,0 +1,202 @@
# Guide for using `scan` and `scan_layers`

This is a guide for using `scan` and `scan_layers` in PyTorch/XLA.

## When should you use this

You should consider using [`scan_layers`][scan_layers] if you have a model with
many homogeneous (same shape, same logic) layers, for example LLMs. These models
can be slow to compile. `scan_layers` is a drop-in replacement for a for loop over
homogeneous layers, such as a stack of decoder layers. `scan_layers` traces the
first layer and reuses the compiled result for all subsequent layers, significantly
reducing the model compile time.

[`scan`][scan], on the other hand, is a lower-level higher-order op modeled after
[`jax.lax.scan`][jax-lax-scan]. Its primary purpose is to help implement
`scan_layers` under the hood. However, you may find it useful if you would like
to program some sort of loop logic where the loop itself has a first-class
representation in the compiler (specifically, an XLA `While` op).

## `scan_layers` example

Typically, a transformer model passes the input embedding through a sequence of
homogeneous decoder layers like the following:

```python
def run_decoder_layers(self, hidden_states):
  for decoder_layer in self.layers:
    hidden_states = decoder_layer(hidden_states)
  return hidden_states
```

When this function is lowered into an HLO graph, the for loop is unrolled into a
flat list of operations, resulting in long compile times. To reduce compile
times, you can replace the for loop with a call to `scan_layers`, as shown in
[`decoder_with_scan.py`][decoder_with_scan]:

```python
def run_decoder_layers(self, hidden_states):
  from torch_xla.experimental.scan_layers import scan_layers
  return scan_layers(self.layers, hidden_states)
```
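
Outside of a full decoder, the same call works on any stack of identically-shaped
modules. Below is a minimal, self-contained sketch; it assumes an XLA device
(e.g., a TPU) is available, and the layer count and sizes are made up for
illustration:

```python
import torch
import torch.nn as nn
import torch_xla
from torch_xla.experimental.scan_layers import scan_layers

# A stack of homogeneous layers: same shapes, same logic.
layers = nn.ModuleList([nn.Linear(128, 128) for _ in range(8)])
layers = layers.to(torch_xla.device())
x = torch.randn(4, 128, device=torch_xla.device())

# Only the first layer is traced; the compiled result is reused for the rest.
out = scan_layers(layers, x)
torch_xla.sync()
print(out.shape)  # torch.Size([4, 128])
```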

You can train this decoder model by running the following command from the root
directory of a `pytorch/xla` source checkout.

```sh
python3 examples/train_decoder_only_base.py scan.decoder_with_scan.DecoderWithScan
```

## `scan` example

[`scan`][scan] takes a combine function and applies it over the leading
dimension of tensors while carrying along state:

```python
def scan(
    fn: Callable[[Carry, X], tuple[Carry, Y]],
    init: Carry,
    xs: X,
) -> tuple[Carry, Y]:
  ...
```

You can use it to loop over the leading dimension of tensors efficiently. If `xs`
is a single tensor, this function is roughly equivalent to the following Python code:

```python
def scan(fn, init, xs):
  ys = []
  carry = init
  for i in range(xs.size(0)):
    carry, y = fn(carry, xs[i])
    ys.append(y)
  return carry, torch.stack(ys, dim=0)
```
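
As a concrete illustration of these semantics, here is a minimal running-sum
sketch (it assumes an XLA device is available, and mirrors the cumulative-sum
example described below):

```python
import torch
import torch_xla
from torch_xla.experimental.scan import scan

def running_sum(carry, x):
  new_carry = carry + x
  return new_carry, new_carry  # carry forward, and also emit as output

init = torch.zeros(1, device=torch_xla.device())
xs = torch.arange(1.0, 4.0, device=torch_xla.device())  # tensor([1., 2., 3.])
carry, ys = scan(running_sum, init, xs)
torch_xla.sync()
print(carry)  # tensor([6.], device='xla:0')
print(ys)     # tensor([[1.], [3.], [6.]], device='xla:0')
```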

Under the hood, `scan` is implemented much more efficiently by lowering the loop
into an XLA `While` operation. This ensures that only one iteration of the loop
is compiled by XLA.

[`scan_examples.py`][scan_examples] contains example code showing how to use
`scan`. In that file, `scan_example_cumsum` uses `scan` to implement a cumulative
sum, and `scan_example_pytree` demonstrates how to pass PyTrees to `scan`.

You can run the examples with:

```sh
python3 examples/scan/scan_examples.py
```

The output should look something like the following:

```
Running example: scan_example_cumsum
Final sum: tensor([6.], device='xla:0')
History of sums tensor([[1.],
        [3.],
        [6.]], device='xla:0')


Running example: scan_example_pytree
Final carry: {'sum': tensor([15.], device='xla:0'), 'count': tensor([5.], device='xla:0')}
Means over time: tensor([[1.0000],
        [1.5000],
        [2.0000],
        [2.5000],
        [3.0000]], device='xla:0')
```

## Limitations

### AOTAutograd compatibility requirement

The functions/modules passed to `scan` and `scan_layers` must be AOTAutograd
traceable. In particular, as of PyTorch/XLA 2.6, `scan` and `scan_layers` cannot
trace functions with custom Pallas kernels. That means if your decoder uses,
for example, flash attention, then it is incompatible with `scan`. We are working
on [supporting this important use case][flash-attn-issue] in nightly and upcoming
releases.

### AOTAutograd overhead

Because `scan` uses AOTAutograd to figure out the backward pass of the input
function/module on every iteration, it is easy to become tracing-bound compared
to a for loop implementation. In fact, as of PyTorch/XLA 2.6, the
`train_decoder_only_base.py` example runs slower under `scan` than with a for
loop due to this overhead. We are working on
[improving tracing speed][retracing-issue]. This is less of a problem when your
model is very large or has many layers, which are exactly the situations where
you would want to use `scan` anyway.
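
If you suspect this overhead in your own workload, one way to check is to dump
PyTorch/XLA's metrics report after a few training steps and inspect entries such
as `CompileTime` (a sketch; the experiments below print the same style of report
via `--print-metrics`):

```python
import torch_xla.debug.metrics as met

# After running a few training steps, print the accumulated metrics;
# the CompileTime metric shown in the experiments below comes from here.
print(met.metrics_report())
```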

## Compile time experiments

To demonstrate the compile time savings, we'll train a simple decoder with many
layers on a single TPU chip, first with a for loop and then with `scan_layers`.

- Run the for loop implementation:

  ```sh
  ❯ python3 examples/train_decoder_only_base.py \
      --hidden-size 256 \
      --num-layers 50 \
      --num-attention-heads 4 \
      --num-key-value-heads 2 \
      --intermediate-size 2048 \
      --num-steps 5 \
      --print-metrics

  ...

  Metric: CompileTime
    TotalSamples: 3
    Accumulator: 02m57s694ms418.595us
    ValueRate: 02s112ms586.097us / second
    Rate: 0.054285 / second
    Percentiles: 1%=023ms113.470us; 5%=023ms113.470us; 10%=023ms113.470us; 20%=023ms113.470us; 50%=54s644ms733.284us; 80%=01m03s028ms571.841us; 90%=01m03s028ms571.841us; 95%=01m03s028ms571.841us; 99%=01m03s028ms571.841us
  ```

- Run the `scan_layers` implementation:

  ```sh
  ❯ python3 examples/train_decoder_only_base.py \
      scan.decoder_with_scan.DecoderWithScan \
      --hidden-size 256 \
      --num-layers 50 \
      --num-attention-heads 4 \
      --num-key-value-heads 2 \
      --intermediate-size 2048 \
      --num-steps 5 \
      --print-metrics

  ...

  Metric: CompileTime
    TotalSamples: 3
    Accumulator: 29s996ms941.409us
    ValueRate: 02s529ms591.388us / second
    Rate: 0.158152 / second
    Percentiles: 1%=018ms636.571us; 5%=018ms636.571us; 10%=018ms636.571us; 20%=018ms636.571us; 50%=11s983ms003.171us; 80%=18s995ms301.667us; 90%=18s995ms301.667us; 95%=18s995ms301.667us; 99%=18s995ms301.667us
  ```

We can see that the maximum compile time dropped from `1m03s` to `19s` by
switching to `scan_layers`.

## References

See https://github.com/pytorch/xla/issues/7253 for the design of `scan` and
`scan_layers` themselves.

See the function doc comments of [`scan`][scan] and [`scan_layers`][scan_layers]
for details on how to use them.

<!-- xrefs -->

[scan]: https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan.py
[scan_layers]: https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan_layers.py
[flash-attn-issue]: https://github.com/pytorch/xla/issues/8633
[retracing-issue]: https://github.com/pytorch/xla/issues/8632
[jax-lax-scan]: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html
[decoder_with_scan]: /examples/scan/decoder_with_scan.py
[scan_examples]: /examples/scan/scan_examples.py

docs/source/index.rst

+1
@@ -40,6 +40,7 @@ PyTorch/XLA is a Python package that uses the XLA deep learning compiler to conn
    features/pallas.md
    features/stablehlo.md
    features/triton.md
+   features/scan.md

 .. toctree::
    :glob:

examples/decoder_only_model.py

+9 −6

@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from typing import Optional
 import math
 
 import torch
@@ -201,7 +200,7 @@ def forward(
 class DecoderOnlyModel(nn.Module):
 
   def __init__(self, config: DecoderOnlyConfig):
-    super(DecoderOnlyModel, self).__init__()
+    super().__init__()
     self.vocab_size = config.vocab_size
     self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
     self.layers = nn.ModuleList(
@@ -211,18 +210,22 @@ def __init__(self, config: DecoderOnlyConfig):
 
   def forward(
       self,
-      input_ids: torch.LongTensor = None,
+      input_ids: torch.LongTensor,
   ) -> torch.Tensor:
     inputs_embeds = self.embed_tokens(input_ids)
 
     # embed positions
     hidden_states = inputs_embeds
 
     # decoder layers
-    for idx, decoder_layer in enumerate(self.layers):
-      layer_outputs = decoder_layer(hidden_states,)
-      hidden_states = layer_outputs
+    hidden_states = self.run_decoder_layers(hidden_states)
 
     hidden_states = self.norm(hidden_states)
+
     # [B, S, H] -> [B, S, V]
     return self.output(hidden_states)
+
+  def run_decoder_layers(self, hidden_states):
+    for decoder_layer in self.layers:
+      hidden_states = decoder_layer(hidden_states)
+    return hidden_states

examples/scan/README.md

-2
This file was deleted.

examples/scan/README.md

+1
@@ -0,0 +1 @@
../../docs/source/features/scan.md

examples/scan/decoder_with_scan.py

+13
@@ -0,0 +1,13 @@
from typing_extensions import override
from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel


class DecoderWithScan(DecoderOnlyModel):

  def __init__(self, config: DecoderOnlyConfig):
    super().__init__(config)

  @override
  def run_decoder_layers(self, hidden_states):
    from torch_xla.experimental.scan_layers import scan_layers
    return scan_layers(self.layers, hidden_states)
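
A quick way to smoke-test this class is shown below. This is a hypothetical
usage sketch: it assumes the script runs with `examples/` on the import path
(as `decoder_with_scan.py` itself does) and that `DecoderOnlyConfig` provides
defaults for all of its fields.

```python
import torch
import torch_xla
from decoder_only_model import DecoderOnlyConfig
from decoder_with_scan import DecoderWithScan

config = DecoderOnlyConfig()  # assumed: the dataclass defines default values
model = DecoderWithScan(config).to(torch_xla.device())
input_ids = torch.randint(
    0, config.vocab_size, (1, 16), device=torch_xla.device())
logits = model(input_ids)  # decoder layers run under scan_layers, not a for loop
torch_xla.sync()
print(logits.shape)  # [batch, seq, vocab] per the model's forward
```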

examples/scan/scan_examples.py

+85
@@ -0,0 +1,85 @@
import torch
import torch_xla

from torch_xla.experimental.scan import scan


def scan_example_cumsum():
  """
  This example uses the `scan` function to compute the cumulative sum of a tensor.
  """

  # 1) Define a combine function that takes in the accumulated sum and the next element,
  #    and returns the new accumulated sum. We return two values: one is the "carry" that
  #    will be passed to the next iteration of this function call, and the other is the
  #    "output" that will be stacked into the final result.
  def cumsum(accumulated, element):
    accumulated += element
    return accumulated, accumulated

  # 2) Define an initial carry and the input tensor.
  init_sum = torch.tensor([0.0], device=torch_xla.device())
  xs = torch.tensor([1.0, 2.0, 3.0], device=torch_xla.device())
  torch_xla.sync()

  # 3) Call `scan` with our combine function, initial carry, and input tensor.
  final, result = scan(cumsum, init_sum, xs)
  torch_xla.sync()

  print("Final sum:", final)
  print("History of sums", result)


def scan_example_pytree():
  """
  This example uses the `scan` function to compute a running mean.

  It demonstrates using PyTrees as inputs and outputs, in particular, dictionaries.
  """
  # 1) Define an initial carry as a dictionary with two leaves:
  #    - 'sum' to accumulate the sum of all seen values
  #    - 'count' to count how many values have been seen
  carry = {
      'sum': torch.tensor([0.0], device=torch_xla.device()),
      'count': torch.tensor([0.0], device=torch_xla.device())
  }

  # 2) Define our input PyTree, which in this case is just a dictionary with one leaf:
  #    - 'values' is a 1D tensor representing data points we want to scan over.
  xs = {
      'values':
          torch.arange(1, 6, dtype=torch.float32, device=torch_xla.device())
  }

  # Here, xs['values'] has shape [5]. The `scan` function will automatically slice
  # out one element (shape []) each iteration.

  # 3) Define our function (akin to a "step" function in jax.lax.scan). It:
  #    - takes in the current carry and the current slice of xs,
  #    - updates the sum/count in the carry,
  #    - computes a new output (the running mean),
  #    - returns the updated carry and that output.
  def fn(carry_dict, x_dict):
    new_sum = carry_dict['sum'] + x_dict['values']
    new_count = carry_dict['count'] + 1.0
    new_carry = {'sum': new_sum, 'count': new_count}
    running_mean = new_sum / new_count
    return new_carry, running_mean

  # 4) Call `scan` with our step function, initial carry, and input dictionary.
  final_carry, means_over_time = scan(fn, carry, xs)

  # 5) `final_carry` contains the final sum/count, while `means_over_time` is
  #    a 1D tensor with the running mean at each step.
  print("Final carry:", final_carry)
  print("Means over time:", means_over_time)


if __name__ == "__main__":
  for example in [
      scan_example_cumsum,
      scan_example_pytree,
  ]:
    print(f"\nRunning example: {example.__name__}", flush=True)
    example()
    print(flush=True)
