GSPMD + PyTorch Compile + TPU crash #4824

Open
@agemagician

Description

Hi,

I am trying to combine GSPMD with PyTorch Compile (torch.compile) on TPU, but it doesn't work.
I took a copy of the test script "test_train_spmd_imagenet.py" and ran it in Colab, where it started normally. However, after I added the compile line:

device = xm.xla_device()
model = get_model_property('model_fn')().to(device)

model = torch.compile(model, backend='aot_torchxla_trace_once')

It crashed.
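
To help confirm that the compile line is the trigger, the wrapping can be made switchable and any Python-level exception recorded. The following is only a minimal sketch: `maybe_compile` and `run_and_capture` are hypothetical helper names, not part of the test script or of torch_xla, and a hard crash of the process (for example a segfault in the runtime) would not be caught this way.

```python
import traceback
from typing import Any, Callable, Optional, Tuple


def maybe_compile(model: Any, compile_fn: Optional[Callable[[Any], Any]]) -> Any:
    """Return compile_fn(model) when a compile function is given, else the model unchanged.

    Hypothetical helper. In the script, compile_fn could be, for example:
        lambda m: torch.compile(m, backend='aot_torchxla_trace_once')
    """
    return model if compile_fn is None else compile_fn(model)


def run_and_capture(step: Callable[[], Any]) -> Tuple[bool, str]:
    """Run one step and return (succeeded, traceback text).

    Only Python exceptions are captured; the traceback text is empty on success.
    """
    try:
        step()
        return True, ""
    except Exception:
        return False, traceback.format_exc()
```

Running the same training step once with `compile_fn=None` and once with the compile function would show whether the failure appears only on the compiled path, and the captured traceback text could then be attached to the report.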

Here is a Colab example to reproduce the results:
https://colab.research.google.com/drive/1KNcBydAfZXLATpSo-CXILxtHJkK8JD-2?usp=sharing
