doc: update pytorch-on-xla-devices and troubleshoot doc for tensor synchronization issue #9258

Open · wants to merge 8 commits into master
66 changes: 66 additions & 0 deletions docs/source/learn/pytorch-on-xla-devices.md
@@ -346,6 +346,72 @@ device is unavailable the load will fail. PyTorch/XLA, like all of
PyTorch, is under active development and this behavior may change in the
future.

### Unexpected Tensor Materialization During AOT (Ahead-of-Time) Tracing

While tensor materialization is normal in a JIT workflow, it is not expected during traced inference (i.e. [AOT model tracing in AWS Neuron](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-neuronx/programming-guide/inference/trace-vs-xla-lazytensor.html)).
When working with traced inference, developers may encounter tensor materialization, which causes graphs to be compiled based on the example input tensor values and leads to unexpected program behavior.
PyTorch/XLA's debugging flags help identify when unexpected tensor materialization happens, so that appropriate code changes can be made to avoid it.


A common issue occurs when tensor values are evaluated during model compilation (traced inference). Consider this example:
```python
import torch

class TestModel(torch.nn.Module):
    def forward(self, tensor):
        # reading tensor[0] materializes the tensor's value at trace time
        if tensor[0] == 1:
            return tensor
        else:
            return tensor * 2
```

While this code can compile and run, it may lead to unexpected behavior because:

* The tensor value is accessed during tracing (``tensor[0]``).
* The resulting graph is fixed based on whatever tensor value was available during tracing.
* Developers might incorrectly assume the condition will be evaluated dynamically during inference.

The solution for the code above is to use the debugging flags below to catch the issue and then modify the code, for example by feeding the flag through the model configuration instead of a runtime tensor.

Here is the updated code, which avoids tensor materialization:
```python
class TestModel(torch.nn.Module):
def __init__(self, flag=1):
super().__init__()
# the flag should be pre-determined based on the model configuration
# it should not be an input of the model during runtime
self.flag = flag

def forward(self, tensor):
if self.flag:
return tensor
else:
return tensor * 2
```
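
For illustration, here is a minimal usage sketch (assuming the `TestModel` class above; the `model_config` dict is a hypothetical stand-in for however your deployment resolves its configuration). The flag is fixed at construction time, so the traced graph is built for the intended branch rather than depending on a runtime tensor:

```python
# hypothetical configuration source; in practice this comes from your
# model/deployment configuration, never from runtime tensor inputs
model_config = {"flag": 0}

model = TestModel(flag=model_config["flag"])
# the traced graph now always takes the `tensor * 2` branch
```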


#### Debugging Flags

To help catch tensor materialization issues, PyTorch/XLA provides two useful approaches:

1. Enable warning messages for tensor materialization:
```python
import os
os.environ['PT_XLA_DEBUG_LEVEL'] = '2'
```
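
Note: this variable may be read when `torch_xla` initializes, so to be safe set it before importing `torch_xla`.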

2. Disable graph execution to catch issues during development:
```python
import torch_xla
torch_xla._XLAC._set_allow_execution(False)
```
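
As a minimal end-to-end sketch (the `BadModel` name, tensor shape, and broad `except Exception` are illustrative assumptions, and the exact exception type may vary by version), combining both flags makes the materialization in the first example fail loudly at trace time instead of silently freezing the branch:

```python
import os
os.environ['PT_XLA_DEBUG_LEVEL'] = '2'  # warn on tensor materialization

import torch
import torch_xla
import torch_xla.core.xla_model as xm

class BadModel(torch.nn.Module):
    def forward(self, tensor):
        # reading tensor[0] materializes the tensor during tracing
        if tensor[0] == 1:
            return tensor
        return tensor * 2

device = xm.xla_device()
model = BadModel().to(device)
example = torch.ones(4, device=device)

# disallow graph execution so any materialization during tracing fails loudly
torch_xla._XLAC._set_allow_execution(False)
try:
    model(example)  # bool(tensor[0] == 1) needs a value -> execution attempt
except Exception as e:
    print(f"caught tensor materialization during tracing: {e}")
finally:
    # restore the default so later code can execute graphs normally
    torch_xla._XLAC._set_allow_execution(True)
```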

#### Recommendations

Using these flags during development can help identify potential issues early in the development cycle. The recommended approach is to:

* Use ``PT_XLA_DEBUG_LEVEL=2`` during initial development to identify potential materialization points
* Apply ``_set_allow_execution(False)`` when you want to ensure no tensor materialization occurs during tracing
* When you see warnings or errors related to tensor materialization, inspect the code path and make the appropriate changes. In the example above, the flag was moved to the `__init__` method so that it does not depend on the model's runtime input.

For more detailed debugging information, refer to the [XLA troubleshooting guide](https://github.com/pytorch/xla/blob/master/docs/source/learn/troubleshoot.md#pytorchxla-debugging-tool).


## Compilation Caching

The XLA compiler converts the traced HLO into an executable which runs
22 changes: 14 additions & 8 deletions docs/source/learn/troubleshoot.md
@@ -137,19 +137,25 @@ Execution Analysis: ------------------------------------------------------------
Execution Analysis: ================================================================================
```

Some common causes of compilation/execution are:
1. User manually calls
`torch_xla.sync()`.
2. [Parallel
loader](https://github.com/pytorch/xla/blob/fe4af0080af07f78ca2b614dd91b71885a3bbbb8/torch_xla/distributed/parallel_loader.py#L49-L51)
calls `torch_xla.sync()` every x (configurable) batches.
3. Exiting a
[profiler StepTrace
region](https://github.com/pytorch/xla/blob/fe4af0080af07f78ca2b614dd91b71885a3bbbb8/torch_xla/debug/profiler.py#L165-L171).
4. Dynamo decides to compile/execute the graph.
5. User tries to
access (often due to logging) the value of a tensor before
`torch_xla.sync()`.
6. User tries to access a tensor value before calling `mark_step`. See [PyTorch on XLA Devices](https://github.com/pytorch/xla/blob/master/docs/source/learn/pytorch-on-xla-devices.md) for more details.

The op executions caused by items 1-4 are expected, and we want to avoid item 5 by
either reducing the frequency of accessing tensor values or manually adding a call to
`torch_xla.sync()` before accessing them.
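
As a minimal sketch of item 5 (the shapes and tensors here are illustrative), accessing a lazy tensor's value too early forces an extra execution, while syncing first keeps the graph cut where you expect it:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
w = torch.randn(10, 10, device=device)
x = torch.randn(10, device=device)
loss = (w @ x).sum()

# anti-pattern: calling loss.item() here would materialize `loss` and
# trigger a compilation/execution before the step's graph is cut

# preferred: cut the graph first, then access the value
torch_xla.sync()
print(loss.item())
```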

Users should expect to see these `Compilation Cause` +
`Execution Cause` pairs for the first couple of steps. After the model