# Guide for using `scan` and `scan_layers`

This is a guide for using `scan` and `scan_layers` in PyTorch/XLA.

## When should you use this

You should consider using [`scan_layers`][scan_layers] if you have a model with
many homogeneous (same shape, same logic) layers, for example LLMs. These models
can be slow to compile. `scan_layers` is a drop-in replacement for a for loop over
homogeneous layers, such as a stack of decoder layers. `scan_layers` traces the
first layer and reuses the compiled result for all subsequent layers, significantly
reducing the model compile time.

[`scan`][scan], on the other hand, is a lower-level higher-order op modeled after
[`jax.lax.scan`][jax-lax-scan]. Its primary purpose is to help implement
`scan_layers` under the hood. However, you may find it useful if you would like
to program some sort of loop logic where the loop itself has a first-class
representation in the compiler (specifically, an XLA `While` op).

## `scan_layers` example

Typically, a transformer model passes the input embedding through a sequence of
homogeneous decoder layers like the following:

```python
def run_decoder_layers(self, hidden_states):
  for decoder_layer in self.layers:
    hidden_states = decoder_layer(hidden_states)
  return hidden_states
```

When this function is lowered into an HLO graph, the for loop is unrolled into a
flat list of operations, resulting in long compile times. To reduce compile
times, you can replace the for loop with a call to `scan_layers`, as shown in
[`decoder_with_scan.py`][decoder_with_scan]:

```python
def run_decoder_layers(self, hidden_states):
  from torch_xla.experimental.scan_layers import scan_layers
  return scan_layers(self.layers, hidden_states)
```
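
For a self-contained picture, here is a minimal sketch of a module built around
`scan_layers`. The `TinyDecoder` class and its `nn.Linear` layers are hypothetical
stand-ins for real decoder layers, not the actual example model; the key
requirement is that every layer in the list shares the same structure:

```python
import torch
import torch.nn as nn
from torch_xla.experimental.scan_layers import scan_layers


class TinyDecoder(nn.Module):
  """Illustrative stand-in for a real decoder model."""

  def __init__(self, num_layers: int = 4, hidden_size: int = 128):
    super().__init__()
    # All layers must be homogeneous (same structure, same logic) so that
    # scan_layers can trace one layer and reuse the compiled result.
    self.layers = nn.ModuleList(
        [nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)])

  def forward(self, hidden_states):
    # Replaces `for layer in self.layers: hidden_states = layer(hidden_states)`.
    return scan_layers(self.layers, hidden_states)
```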

You can train this decoder model by running the following command from the root
directory of a `pytorch/xla` source checkout.

```sh
python3 examples/train_decoder_only_base.py scan.decoder_with_scan.DecoderWithScan
```

## `scan` example

[`scan`][scan] takes a combine function and applies that function over the leading
dimension of tensors while carrying along state:

```python
def scan(
    fn: Callable[[Carry, X], tuple[Carry, Y]],
    init: Carry,
    xs: X,
) -> tuple[Carry, Y]:
  ...
```

You can use it to loop over the leading dimension of tensors efficiently. If `xs`
is a single tensor, this function is roughly equivalent to the following Python
code:

```python
def scan(fn, init, xs):
  ys = []
  carry = init
  for i in range(xs.size(0)):
    carry, y = fn(carry, xs[i])
    ys.append(y)
  return carry, torch.stack(ys, dim=0)
```

Under the hood, `scan` is implemented much more efficiently by lowering the loop
into an XLA `While` operation. This ensures that only one iteration of the loop
is compiled by XLA.

[`scan_examples.py`][scan_examples] contains some example code showing how to use
`scan`. In that file, `scan_example_cumsum` uses `scan` to implement a cumulative
sum. `scan_example_pytree` demonstrates how to pass PyTrees to `scan`.
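
As a rough sketch of the cumulative-sum example (an approximation for
illustration; the exact code lives in [`scan_examples.py`][scan_examples]):

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan import scan


def step(carry, x):
  # The carry is the running sum; emit it as the per-step output as well.
  new_carry = carry + x
  return new_carry, new_carry


device = xm.xla_device()
init = torch.zeros(1, device=device)
xs = torch.tensor([[1.0], [2.0], [3.0]], device=device)
final, history = scan(step, init, xs)
# final   -> tensor([6.], device='xla:0')
# history -> tensor([[1.], [3.], [6.]], device='xla:0')
```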

You can run the examples with:

```sh
python3 examples/scan/scan_examples.py
```

The output should look something like the following:

```
Running example: scan_example_cumsum
Final sum: tensor([6.], device='xla:0')
History of sums tensor([[1.],
        [3.],
        [6.]], device='xla:0')


Running example: scan_example_pytree
Final carry: {'sum': tensor([15.], device='xla:0'), 'count': tensor([5.], device='xla:0')}
Means over time: tensor([[1.0000],
        [1.5000],
        [2.0000],
        [2.5000],
        [3.0000]], device='xla:0')
```
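
For reference, the combine function in the PyTree example plausibly looks
something like the following sketch (illustrative, not the exact code): the
carry is a dict tracking a running sum and count, and each step emits the
running mean:

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan import scan


def step(carry, x):
  # Carries and outputs may be arbitrary pytrees (here, dicts of tensors).
  new_sum = carry['sum'] + x
  new_count = carry['count'] + 1.0
  return {'sum': new_sum, 'count': new_count}, new_sum / new_count


device = xm.xla_device()
init = {
    'sum': torch.zeros(1, device=device),
    'count': torch.zeros(1, device=device),
}
xs = torch.arange(1.0, 6.0, device=device).unsqueeze(-1)  # [[1.], ..., [5.]]
final_carry, means = scan(step, init, xs)
# final_carry -> {'sum': tensor([15.]), 'count': tensor([5.])}
# means       -> tensor([[1.0000], [1.5000], [2.0000], [2.5000], [3.0000]])
```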

## Limitations

### AOTAutograd compatibility requirement

Functions and modules passed to `scan` and `scan_layers` must be traceable by
AOTAutograd. In particular, as of PyTorch/XLA 2.6, `scan` and `scan_layers` cannot
trace functions with custom Pallas kernels. That means that if your decoder uses,
for example, flash attention, it is incompatible with `scan`. We are working on
[supporting this important use case][flash-attn-issue] in nightly builds and in
upcoming releases.

### AOTAutograd overhead

Because `scan` uses AOTAutograd to figure out the backward pass of the input
function or module on every iteration, it is easy to become tracing-bound compared
to a for loop implementation. In fact, the `train_decoder_only_base.py` example
runs slower under `scan` than with a for loop as of PyTorch/XLA 2.6 due to this
overhead. We are working on [improving tracing speed][retracing-issue]. This is
less of a problem when your model is very large or has many layers, which are
precisely the situations in which you would want to use `scan` anyway.

## Compile time experiments

To demonstrate the compile time savings, we train a simple decoder with many
layers on a single TPU chip, first with a for loop and then with `scan_layers`.

- Run the for loop implementation:

```sh
❯ python3 examples/train_decoder_only_base.py \
    --hidden-size 256 \
    --num-layers 50 \
    --num-attention-heads 4 \
    --num-key-value-heads 2 \
    --intermediate-size 2048 \
    --num-steps 5 \
    --print-metrics

...

Metric: CompileTime
  TotalSamples: 3
  Accumulator: 02m57s694ms418.595us
  ValueRate: 02s112ms586.097us / second
  Rate: 0.054285 / second
  Percentiles: 1%=023ms113.470us; 5%=023ms113.470us; 10%=023ms113.470us; 20%=023ms113.470us; 50%=54s644ms733.284us; 80%=01m03s028ms571.841us; 90%=01m03s028ms571.841us; 95%=01m03s028ms571.841us;
  99%=01m03s028ms571.841us
```

- Run the `scan_layers` implementation:

```sh
❯ python3 examples/train_decoder_only_base.py \
    scan.decoder_with_scan.DecoderWithScan \
    --hidden-size 256 \
    --num-layers 50 \
    --num-attention-heads 4 \
    --num-key-value-heads 2 \
    --intermediate-size 2048 \
    --num-steps 5 \
    --print-metrics

...

Metric: CompileTime
  TotalSamples: 3
  Accumulator: 29s996ms941.409us
  ValueRate: 02s529ms591.388us / second
  Rate: 0.158152 / second
  Percentiles: 1%=018ms636.571us; 5%=018ms636.571us; 10%=018ms636.571us; 20%=018ms636.571us; 50%=11s983ms003.171us; 80%=18s995ms301.667us; 90%=18s995ms301.667us; 95%=18s995ms301.667us;
  99%=18s995ms301.667us
```

We can see that the maximum compile time dropped from `1m03s` to `19s` by
switching to `scan_layers`.

## References

See https://github.com/pytorch/xla/issues/7253 for the design of `scan` and
`scan_layers` themselves.

See the function doc comments of [`scan`][scan] and [`scan_layers`][scan_layers]
for details on how to use them.

<!-- xrefs -->

[scan]: https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan.py
[scan_layers]: https://github.com/pytorch/xla/blob/master/torch_xla/experimental/scan_layers.py
[flash-attn-issue]: https://github.com/pytorch/xla/issues/8633
[retracing-issue]: https://github.com/pytorch/xla/issues/8632
[jax-lax-scan]: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html
[decoder_with_scan]: /examples/scan/decoder_with_scan.py
[scan_examples]: /examples/scan/scan_examples.py