
GSPMD + PyTorch Compile + TPU crash #4824

Open
@agemagician


Hi,

I am trying to combine GSPMD with PyTorch compilation (torch.compile), but it doesn't work.
I copied the test script "test_train_spmd_imagenet.py" and ran it in Colab, and training started normally. However, after I added the compile call:

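import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# get_model_property is the helper defined in test_train_spmd_imagenet.py
model = get_model_property('model_fn')().to(device)

model = torch.compile(model, backend='aot_torchxla_trace_once')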

It crashed.

Here is a Colab notebook that reproduces the crash:
https://colab.research.google.com/drive/1KNcBydAfZXLATpSo-CXILxtHJkK8JD-2?usp=sharing

Labels: bug (Something isn't working), distributed (SPMD and other distributed things)