
[scan] Avoid re-tracing the combine function on every call #8632

Open
@tengyifei

Description

🚀 Feature

It should be possible to cache the graphs traced by torch_xla.experimental.scan so that the combine function is not re-traced on every call.

Motivation

Today, torch_xla.experimental.scan and scan_layers trace the user function twice: once with AOTAutograd (to obtain the backward pass) and once with LazyTensor (to lower the resulting graphs to HLO). AOTAutograd is very slow, so training can easily become tracing-bound. For example, python3 examples/train_decoder_only_base.py takes 1 min 30 s, but python3 examples/train_decoder_only_base.py scan.decoder_with_scan.DecoderWithScan takes 4 min.

Pitch

We could wait for torch.scan to support autograd (cf. #7901 (comment)), but that will likely take a long time. In the meantime, we can implement simple caching keyed on the id of the input function/module.
